refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions

View file

@ -1,9 +1,13 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
- repo: local
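For context, black mainly normalizes presentation details such as quote style, trailing commas, and line wrapping, which accounts for most of the changes in this diff. A minimal sketch of invoking it as a library (assuming black is installed; the sample string is illustrative):

import black

source = "models = ['gpt-3.5-turbo', 'claude-2']\n"
# format_str applies black's default style; here it normalizes quotes to double quotes
formatted = black.format_str(source, mode=black.Mode())
print(formatted)  # models = ["gpt-3.5-turbo", "claude-2"]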

View file

@ -9,33 +9,37 @@ import os
# Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ['gpt-3.5-turbo', 'claude-2']
models = ["gpt-3.5-turbo", "claude-2"]
# Enter LLM API keys
# https://docs.litellm.ai/docs/providers
os.environ['OPENAI_API_KEY'] = ""
os.environ['ANTHROPIC_API_KEY'] = ""
os.environ["OPENAI_API_KEY"] = ""
os.environ["ANTHROPIC_API_KEY"] = ""
# List of questions to benchmark (replace with your questions)
questions = [
"When will BerriAI IPO?",
"When will LiteLLM hit $100M ARR?"
]
questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"]
# Enter your system prompt here
system_prompt = """
You are LiteLLMs helpful assistant
"""
@click.command()
@click.option('--system-prompt', default="You are a helpful assistant that can answer questions.", help="System prompt for the conversation.")
@click.option(
"--system-prompt",
default="You are a helpful assistant that can answer questions.",
help="System prompt for the conversation.",
)
def main(system_prompt):
for question in questions:
data = [] # Data for the current question
with tqdm(total=len(models)) as pbar:
for model in models:
colored_description = colored(f"Running question: {question} for model: {model}", 'green')
colored_description = colored(
f"Running question: {question} for model: {model}", "green"
)
pbar.set_description(colored_description)
start_time = time.time()
@ -44,35 +48,43 @@ def main(system_prompt):
max_tokens=500,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
{"role": "user", "content": question},
],
)
end = time.time()
total_time = end - start_time
cost = completion_cost(completion_response=response)
raw_response = response['choices'][0]['message']['content']
raw_response = response["choices"][0]["message"]["content"]
data.append({
'Model': colored(model, 'light_blue'),
'Response': raw_response, # Colorize the response
'ResponseTime': colored(f"{total_time:.2f} seconds", "red"),
'Cost': colored(f"${cost:.6f}", 'green'), # Colorize the cost
})
data.append(
{
"Model": colored(model, "light_blue"),
"Response": raw_response, # Colorize the response
"ResponseTime": colored(f"{total_time:.2f} seconds", "red"),
"Cost": colored(f"${cost:.6f}", "green"), # Colorize the cost
}
)
pbar.update(1)
# Separate headers from the data
headers = ['Model', 'Response', 'Response Time (seconds)', 'Cost ($)']
headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"]
colwidths = [15, 80, 15, 10]
# Create a nicely formatted table for the current question
table = tabulate([list(d.values()) for d in data], headers, tablefmt="grid", maxcolwidths=colwidths)
table = tabulate(
[list(d.values()) for d in data],
headers,
tablefmt="grid",
maxcolwidths=colwidths,
)
# Print the table for the current question
colored_question = colored(question, 'green')
colored_question = colored(question, "green")
click.echo(f"\nBenchmark Results for '{colored_question}':")
click.echo(table) # Display the formatted table
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -1,25 +1,22 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import litellm
from litellm import embedding, completion, completion_cost
from autoevals.llm import *
###################
import litellm
# litellm completion call
question = "which country has the highest population"
response = litellm.completion(
model = "gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": question
}
],
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": question}],
)
print(response)
# use the auto eval Factuality() evaluator
@ -27,9 +24,11 @@ print(response)
print("calling evaluator")
evaluator = Factuality()
result = evaluator(
output=response.choices[0]["message"]["content"], # response from litellm.completion()
output=response.choices[0]["message"][
"content"
], # response from litellm.completion()
expected="India", # expected output
input=question # question passed to litellm.completion
input=question, # question passed to litellm.completion
)
print(result)

View file

@ -7,6 +7,7 @@ from util import handle_error
from litellm import completion
import os, dotenv, time
import json
dotenv.load_dotenv()
# TODO: set your keys in .env or here:
@ -19,47 +20,61 @@ verbose = True
# litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/
######### PROMPT LOGGING ##########
os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/
os.environ[
"PROMPTLAYER_API_KEY"
] = "" # set your promptlayer key here - https://promptlayer.com/
# set callbacks
litellm.success_callback = ["promptlayer"]
############ HELPER FUNCTIONS ###################################
def print_verbose(print_statement):
if verbose:
print(print_statement)
app = Flask(__name__)
CORS(app)
@app.route('/')
@app.route("/")
def index():
return 'received!', 200
return "received!", 200
def data_generator(response):
for chunk in response:
yield f"data: {json.dumps(chunk)}\n\n"
@app.route('/chat/completions', methods=["POST"])
@app.route("/chat/completions", methods=["POST"])
def api_completion():
data = request.json
start_time = time.time()
if data.get('stream') == "True":
data['stream'] = True # convert to boolean
if data.get("stream") == "True":
data["stream"] = True # convert to boolean
try:
if "prompt" not in data:
raise ValueError("data needs to have prompt")
data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
data[
"model"
] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
# COMPLETION CALL
system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}]
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": data.pop("prompt")},
]
data["messages"] = messages
print(f"data: {data}")
response = completion(**data)
## LOG SUCCESS
end_time = time.time()
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return Response(data_generator(response), mimetype='text/event-stream')
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
return Response(data_generator(response), mimetype="text/event-stream")
except Exception as e:
# call handle_error function
print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
@ -69,7 +84,8 @@ def api_completion():
return handle_error(data=data)
return response
@app.route('/get_models', methods=["POST"])
@app.route("/get_models", methods=["POST"])
def get_models():
try:
return litellm.model_list
@ -78,7 +94,8 @@ def get_models():
response = {"error": str(e)}
return response, 200
if __name__ == "__main__":
from waitress import serve
serve(app, host="0.0.0.0", port=4000, threads=500)
serve(app, host="0.0.0.0", port=4000, threads=500)

View file

@ -8,17 +8,18 @@ def get_next_url(response):
:param response: response from requests
:return: next url or None
"""
if 'link' not in response.headers:
if "link" not in response.headers:
return None
headers = response.headers
next_url = headers['Link']
next_url = headers["Link"]
print(next_url)
start_index = next_url.find("<")
end_index = next_url.find(">")
return next_url[1:end_index]
def get_models(url):
"""
Function to retrieve all models from paginated endpoint
@ -36,6 +37,7 @@ def get_models(url):
models.extend(payload)
return models
def get_cleaned_models(models):
"""
Function to clean retrieved models
@ -47,8 +49,9 @@ def get_cleaned_models(models):
cleaned_models.append(model["id"])
return cleaned_models
# Get text-generation models
url = 'https://huggingface.co/api/models?filter=text-generation-inference'
url = "https://huggingface.co/api/models?filter=text-generation-inference"
text_generation_models = get_models(url)
cleaned_text_generation_models = get_cleaned_models(text_generation_models)
@ -56,7 +59,7 @@ print(cleaned_text_generation_models)
# Get conversational models
url = 'https://huggingface.co/api/models?filter=conversational'
url = "https://huggingface.co/api/models?filter=conversational"
conversational_models = get_models(url)
cleaned_conversational_models = get_cleaned_models(conversational_models)
@ -69,15 +72,19 @@ def write_to_txt(cleaned_models, filename):
:param cleaned_models: list of cleaned models
:param filename: name of the text file
"""
with open(filename, 'w') as f:
with open(filename, "w") as f:
for item in cleaned_models:
f.write("%s\n" % item)
# Write contents of cleaned_text_generation_models to text_generation_models.txt
write_to_txt(cleaned_text_generation_models, 'huggingface_llms_metadata/hf_text_generation_models.txt')
write_to_txt(
cleaned_text_generation_models,
"huggingface_llms_metadata/hf_text_generation_models.txt",
)
# Write contents of cleaned_conversational_models to conversational_models.txt
write_to_txt(cleaned_conversational_models, 'huggingface_llms_metadata/hf_conversational_models.txt')
write_to_txt(
cleaned_conversational_models,
"huggingface_llms_metadata/hf_conversational_models.txt",
)

View file

@ -1,4 +1,3 @@
import openai
api_base = f"http://0.0.0.0:8000"
@ -8,29 +7,29 @@ openai.api_key = "temp-key"
print(openai.api_base)
print(f'LiteLLM: response from proxy with streaming')
print(f"LiteLLM: response from proxy with streaming")
response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
],
stream=True
stream=True,
)
for chunk in response:
print(f'LiteLLM: streaming response from proxy {chunk}')
print(f"LiteLLM: streaming response from proxy {chunk}")
response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
]
],
)
print(f'LiteLLM: response from proxy {response}')
print(f"LiteLLM: response from proxy {response}")

View file

@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@ -59,7 +68,6 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
@ -74,10 +82,18 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
messages=[
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
}
],
)
print(response)
end_time = time.time()
@ -92,11 +108,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 100
@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
print("\nError Log:\n", error_log_file.read())

View file

@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
# os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@ -59,7 +68,6 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
@ -76,9 +84,12 @@ def make_openai_completion(question):
import requests
data = {
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': f'You are a helpful assistant. Answer this question{question}'},
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
},
],
}
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
@ -107,7 +118,9 @@ def make_openai_completion(question):
)
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
)
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)
@ -117,11 +130,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 10
@ -152,4 +164,3 @@ print(f"Load test Summary:")
print(f"Total Requests: {concurrent_calls}")
print(f"Successful Calls: {successful_calls}")
print(f"Failed Calls: {failed_calls}")

View file

@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@ -59,7 +68,6 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
@ -75,7 +83,12 @@ def make_openai_completion(question):
start_time = time.time()
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
messages=[
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
}
],
)
print(response)
end_time = time.time()
@ -90,11 +103,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 150
@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
print("\nError Log:\n", error_log_file.read())

View file

@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
Callable
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Callable
] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
email: Optional[
@ -44,12 +50,80 @@ use_client: bool = False
logging: bool = True
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
_openai_completion_params = [
"functions",
"function_call",
"temperature",
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
]
_litellm_completion_params = [
"metadata",
"acompletion",
"caching",
"mock_response",
"api_key",
"api_version",
"api_base",
"force_timeout",
"logger_fn",
"verbose",
"custom_llm_provider",
"litellm_logging_obj",
"litellm_call_id",
"use_client",
"id",
"fallbacks",
"azure",
"headers",
"model_list",
"num_retries",
"context_window_fallback_dict",
"roles",
"final_prompt_value",
"bos_token",
"eos_token",
"request_timeout",
"complete_response",
"self",
"client",
"rpm",
"tpm",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
"model_info",
"proxy_server_request",
"preset_cache_key",
]
_current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
@ -66,23 +140,35 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
allowed_fails: int = 0
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
#############################################
def get_model_cost_map(url: str):
try:
with requests.get(url, timeout=5) as response: # set a 5 second timeout for the get request
with requests.get(
url, timeout=5
) as response: # set a 5 second timeout for the get request
response.raise_for_status() # Raise an exception if the request is unsuccessful
content = response.json()
return content
except Exception as e:
import importlib.resources
import json
with importlib.resources.open_text("litellm", "model_prices_and_context_window_backup.json") as f:
with importlib.resources.open_text(
"litellm", "model_prices_and_context_window_backup.json"
) as f:
content = json.load(f)
return content
model_cost = get_model_cost_map(url=model_cost_map_url)
custom_prompt_dict:Dict[str, dict] = {}
custom_prompt_dict: Dict[str, dict] = {}
####### THREAD-SPECIFIC DATA ###################
class MyLocal(threading.local):
def __init__(self):
@ -123,39 +209,39 @@ bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
for key, value in model_cost.items():
if value.get('litellm_provider') == 'openai':
if value.get("litellm_provider") == "openai":
open_ai_chat_completion_models.append(key)
elif value.get('litellm_provider') == 'text-completion-openai':
elif value.get("litellm_provider") == "text-completion-openai":
open_ai_text_completion_models.append(key)
elif value.get('litellm_provider') == 'cohere':
elif value.get("litellm_provider") == "cohere":
cohere_models.append(key)
elif value.get('litellm_provider') == 'anthropic':
elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key)
elif value.get('litellm_provider') == 'openrouter':
elif value.get("litellm_provider") == "openrouter":
openrouter_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-text-models':
elif value.get("litellm_provider") == "vertex_ai-text-models":
vertex_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
elif value.get("litellm_provider") == "vertex_ai-code-text-models":
vertex_code_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-language-models':
elif value.get("litellm_provider") == "vertex_ai-language-models":
vertex_language_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
elif value.get("litellm_provider") == "vertex_ai-vision-models":
vertex_vision_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
elif value.get("litellm_provider") == "vertex_ai-chat-models":
vertex_chat_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
vertex_code_chat_models.append(key)
elif value.get('litellm_provider') == 'ai21':
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get('litellm_provider') == 'nlp_cloud':
elif value.get("litellm_provider") == "nlp_cloud":
nlp_cloud_models.append(key)
elif value.get('litellm_provider') == 'aleph_alpha':
elif value.get("litellm_provider") == "aleph_alpha":
aleph_alpha_models.append(key)
elif value.get('litellm_provider') == 'bedrock':
elif value.get("litellm_provider") == "bedrock":
bedrock_models.append(key)
elif value.get('litellm_provider') == 'deepinfra':
elif value.get("litellm_provider") == "deepinfra":
deepinfra_models.append(key)
elif value.get('litellm_provider') == 'perplexity':
elif value.get("litellm_provider") == "perplexity":
perplexity_models.append(key)
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
@ -163,16 +249,11 @@ openai_compatible_endpoints: List = [
"api.perplexity.ai",
"api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai",
"api.mistral.ai/v1"
"api.mistral.ai/v1",
]
# this is maintained for Exception Mapping
openai_compatible_providers: List = [
"anyscale",
"mistral",
"deepinfra",
"perplexity"
]
openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
# well supported replicate llms
@ -209,23 +290,18 @@ huggingface_models: List = [
together_ai_models: List = [
# llama llms - chat
"togethercomputer/llama-2-70b-chat",
# llama llms - language / instruct
"togethercomputer/llama-2-70b",
"togethercomputer/LLaMA-2-7B-32K",
"togethercomputer/Llama-2-7B-32K-Instruct",
"togethercomputer/llama-2-7b",
# falcon llms
"togethercomputer/falcon-40b-instruct",
"togethercomputer/falcon-7b-instruct",
# alpaca
"togethercomputer/alpaca-7b",
# chat llms
"HuggingFaceH4/starchat-alpha",
# code llms
"togethercomputer/CodeLlama-34b",
"togethercomputer/CodeLlama-34b-Instruct",
@ -234,29 +310,27 @@ together_ai_models: List = [
"NumbersStation/nsql-llama-2-7B",
"WizardLM/WizardCoder-15B-V1.0",
"WizardLM/WizardCoder-Python-34B-V1.0",
# language llms
"NousResearch/Nous-Hermes-Llama2-13b",
"Austism/chronos-hermes-13b",
"upstage/SOLAR-0-70b-16bit",
"WizardLM/WizardLM-70B-V1.0",
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML
baseten_models: List = [
"qvv0xeq",
"q841o8w",
"31dxrj3",
] # FALCON 7B # WizardLM # Mosaic ML
petals_models = [
"petals-team/StableBeluga2",
]
ollama_models = [
"llama2"
]
ollama_models = ["llama2"]
maritalk_models = [
"maritalk"
]
maritalk_models = ["maritalk"]
model_list = (
open_ai_chat_completion_models
@ -327,7 +401,7 @@ models_by_provider: dict = {
"ollama": ollama_models,
"deepinfra": deepinfra_models,
"perplexity": perplexity_models,
"maritalk": maritalk_models
"maritalk": maritalk_models,
}
# mapping for those models which have larger equivalents
@ -362,15 +436,18 @@ cohere_embedding_models: List = [
"embed-english-light-v2.0",
"embed-multilingual-v2.0",
]
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
bedrock_embedding_models: List = [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
all_embedding_models = (
open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
)
####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = [
"dall-e-2",
"dall-e-3"
]
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout
@ -398,7 +475,7 @@ from .utils import (
decode,
_calculate_retry_after,
_should_retry,
get_secret
get_secret,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,
AmazonAnthropicConfig,
AmazonCohereConfig,
AmazonLlamaConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig
from .main import * # type: ignore
@ -434,7 +517,7 @@ from .exceptions import (
Timeout,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError
UnprocessableEntityError,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server

View file

@ -1,5 +1,6 @@
set_verbose = False
def print_verbose(print_statement):
try:
if set_verbose:

View file

@ -13,6 +13,7 @@ import inspect
import redis, litellm
from typing import List, Optional
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@ -23,23 +24,17 @@ def _get_redis_kwargs():
"retry",
}
include_args = ["url"]
include_args = [
"url"
]
available_args = [
x for x in arg_spec.args if x not in exclude_args
] + include_args
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
return {
f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()
}
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
def _redis_kwargs_from_environment():
@ -58,14 +53,19 @@ def get_redis_url_from_environment():
return os.environ["REDIS_URL"]
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")
raise ValueError(
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
)
if "REDIS_PASSWORD" in os.environ:
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
else:
redis_password = ""
return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
return (
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
)
def get_redis_client(**env_overrides):
### check if "os.environ/<key-name>" passed in
@ -80,14 +80,14 @@ def get_redis_client(**env_overrides):
**env_overrides,
}
if "url" in redis_kwargs and redis_kwargs['url'] is not None:
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
return redis.Redis.from_url(**redis_kwargs)
elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis.Redis(**redis_kwargs)

View file

@ -4,8 +4,14 @@ from litellm.utils import ModelResponse
import requests, threading
from typing import Optional, Union, Literal
class BudgetManager:
def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None):
def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
):
self.client_type = client_type
self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai"
@ -16,6 +22,7 @@ class BudgetManager:
try:
if litellm.set_verbose:
import logging
logging.info(print_statement)
except:
pass
@ -25,7 +32,7 @@ class BudgetManager:
# Check if user dict file exists
if os.path.isfile("user_cost.json"):
# Load the user dict
with open("user_cost.json", 'r') as json_file:
with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file)
else:
self.print_verbose("User Dictionary not found!")
@ -34,40 +41,55 @@ class BudgetManager:
elif self.client_type == "hosted":
# Load the user_dict from hosted db
url = self.api_base + "/get_budget"
headers = {'Content-Type': 'application/json'}
data = {
'project_name' : self.project_name
}
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name}
response = requests.post(url, headers=headers, json=data)
response = response.json()
if response["status"] == "error":
self.user_dict = {} # assume this means the user dict hasn't been stored yet
self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else:
self.user_dict = response["data"]
def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float = time.time()):
def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
created_at: float = time.time(),
):
self.user_dict[user] = {"total_budget": total_budget}
if duration is None:
return self.user_dict[user]
if duration == 'daily':
if duration == "daily":
duration_in_days = 1
elif duration == 'weekly':
elif duration == "weekly":
duration_in_days = 7
elif duration == 'monthly':
elif duration == "monthly":
duration_in_days = 28
elif duration == 'yearly':
elif duration == "yearly":
duration_in_days = 365
else:
raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""")
self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at}
raise ValueError(
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
)
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user]
def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0)
prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost
return projected_cost
@ -75,28 +97,53 @@ class BudgetManager:
def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"]
def update_cost(self, user: str, completion_obj: Optional[ModelResponse] = None, model: Optional[str] = None, input_text: Optional[str] = None, output_text: Optional[str] = None):
def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": input_text}])
completion_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": output_text}])
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
prompt_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": input_text}]
)
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj['model'] # if this throws an error try, model = completion_obj['model']
model = completion_obj[
"model"
] # if this throws an error try, model = completion_obj['model']
else:
raise ValueError("Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager")
raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user]["model_cost"].get(model, 0)
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else:
self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]}
def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0)
@ -135,7 +182,9 @@ class BudgetManager:
self.reset_on_duration(user)
def _save_data_thread(self):
thread = threading.Thread(target=self.save_data) # [Non-Blocking]: saves data without blocking execution
thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start()
def save_data(self):
@ -143,16 +192,15 @@ class BudgetManager:
import json
# save the user dict
with open("user_cost.json", 'w') as json_file:
json.dump(self.user_dict, json_file, indent=4) # Indent for pretty formatting
with open("user_cost.json", "w") as json_file:
json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"}
elif self.client_type == "hosted":
url = self.api_base + "/set_budget"
headers = {'Content-Type': 'application/json'}
data = {
'project_name' : self.project_name,
"user_dict": self.user_dict
}
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = requests.post(url, headers=headers, json=data)
response = response.json()
return response

View file

@ -12,6 +12,7 @@ import time, logging
import json, traceback, ast
from typing import Optional, Literal, List
def print_verbose(print_statement):
try:
if litellm.set_verbose:
@ -19,6 +20,7 @@ def print_verbose(print_statement):
except:
pass
class BaseCache:
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
@ -60,6 +62,7 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache):
def __init__(self, host=None, port=None, password=None, **kwargs):
import redis
# if users don't provider one, use the default litellm cache
from ._redis import get_redis_client
@ -88,12 +91,18 @@ class RedisCache(BaseCache):
try:
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key)
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
if cached_response != None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string
cached_response = cached_response.decode(
"utf-8"
) # Convert bytes to string
try:
cached_response = json.loads(cached_response) # Convert string to dictionary
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
return cached_response
@ -105,13 +114,19 @@ class RedisCache(BaseCache):
def flush_cache(self):
self.redis_client.flushall()
class DualCache(BaseCache):
"""
This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None:
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
@ -162,6 +177,7 @@ class DualCache(BaseCache):
if self.redis_cache is not None:
self.redis_cache.flush_cache()
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
def __init__(
@ -170,8 +186,10 @@ class Cache:
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs,
):
"""
Initializes the cache based on the given type.
@ -222,8 +240,27 @@ class Cache:
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
completion_kwargs = [
"model",
"messages",
"temperature",
"top_p",
"n",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
]
embedding_only_kwargs = [
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs
@ -255,19 +292,30 @@ class Cache:
if model_group in group:
caching_group = group
break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else:
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}"
cache_key += f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key
def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size):
yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(0.02)
def get_cache(self, *args, **kwargs):

View file

@ -48,7 +48,6 @@
# # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
# def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
# config = {}
# server_settings = {}

View file

@ -20,7 +20,7 @@ from openai import (
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError
UnprocessableEntityError,
)
import httpx
@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -58,11 +55,10 @@ class BadRequestError(BadRequestError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422
@ -70,11 +66,10 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 408
@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore
request=request
) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429
@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# subclass of rate limit error - meant to give more granularity for handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -109,9 +104,10 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response
response=response,
) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503
@ -119,24 +115,21 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request):
def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request
):
self.status_code = status_code
self.message = message
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
request=request, # type: ignore
body=None
)
super().__init__(self.message, request=request, body=None) # type: ignore
# raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore
@ -145,10 +138,8 @@ class APIConnectionError(APIConnectionError): # type: ignore
self.llm_provider = llm_provider
self.model = model
self.status_code = 500
super().__init__(
message=self.message,
request=request
)
super().__init__(message=self.message, request=request)
# raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore
@ -158,11 +149,8 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
super().__init__(
response=response,
body=None,
message=message
)
super().__init__(response=response, body=None, message=message)
class OpenAIError(OpenAIError): # type: ignore
def __init__(self, original_exception):
@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore
)
self.llm_provider = "openai"
class BudgetExceededError(Exception):
def __init__(self, current_cost, max_budget):
self.current_cost = current_cost
@ -183,6 +172,7 @@ class BudgetExceededError(Exception):
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
super().__init__(message)
## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider):

View file

@ -5,6 +5,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@ -47,10 +48,19 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
"""
Control / modify the incoming and outgoing data before calling the model
"""
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings"],
):
pass
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
@ -63,14 +73,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
callback_func(
kwargs,
)
print_verbose(
f"Custom Logger - model call details: {kwargs}"
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
async def async_log_input_event(
self, model, messages, kwargs, print_verbose, callback_func
):
try:
kwargs["model"] = model
kwargs["messages"] = messages
@ -78,15 +88,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
await callback_func(
kwargs,
)
print_verbose(
f"Custom Logger - model call details: {kwargs}"
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
def log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
@ -96,15 +105,15 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
print_verbose(f"Custom Logger - final response object: {response_obj}")
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
async def async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
@ -114,9 +123,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
print_verbose(f"Custom Logger - final response object: {response_obj}")
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

View file

@ -10,19 +10,28 @@ import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose
class DyanmoDBLogger:
# Class variables or attributes
def __init__(self):
# Instance variables
import boto3
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
self.dynamodb = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None:
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
raise ValueError(
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
)
self.table_name = litellm.dynamodb_table_name
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try:
print_verbose(
@ -32,7 +41,9 @@ class DyanmoDBLogger:
# construct payload to send to DynamoDB
# follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
@ -51,7 +62,7 @@ class DyanmoDBLogger:
"messages": messages,
"response": response_obj,
"usage": usage,
"metadata": metadata
"metadata": metadata,
}
# Ensure everything in the payload is converted to str
@ -62,7 +73,6 @@ class DyanmoDBLogger:
# non blocking if it can't cast to a str
pass
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
            # put data in DynamoDB
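The hunk stops before the write itself; a hedged sketch of what the put might look like with the boto3 resource created in __init__ (the Table access pattern is an assumption, not taken from the hunk).

# illustrative continuation of log_event
table = self.dynamodb.Table(self.table_name)
table.put_item(Item=payload)  # payload values were already cast to str above
print_verbose(f"DynamoDB Logger - wrote payload to table {self.table_name}")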

View file

@ -64,7 +64,9 @@ class LangFuseLogger:
# end of processing langfuse ########################
input = prompt
output = response_obj["choices"][0]["message"].json()
print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}")
print(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
)
self._log_langfuse_v2(
user_id,
metadata,
@ -171,7 +173,6 @@ class LangFuseLogger:
user_id=user_id,
)
trace.generation(
name=metadata.get("generation_name", "litellm-completion"),
startTime=start_time,

View file

@ -8,12 +8,12 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class LangsmithLogger:
# Class variables or attributes
def __init__(self):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
@ -26,7 +26,9 @@ class LangsmithLogger:
# if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
project_name = metadata.get("project_name", "litellm-completion")
run_name = metadata.get("run_name", "LLMRun")
print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}")
print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try:
print_verbose(
f"Langsmith Logging - Enters logging function for model {kwargs}"
@ -34,6 +36,7 @@ class LangsmithLogger:
import requests
import datetime
from datetime import timezone
try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
@ -45,7 +48,7 @@ class LangsmithLogger:
new_kwargs = {}
for key in kwargs:
value = kwargs[key]
if key == "start_time" or key =="end_time":
if key == "start_time" or key == "end_time":
pass
elif type(value) != dict:
new_kwargs[key] = value
@ -55,17 +58,13 @@ class LangsmithLogger:
json={
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {
**new_kwargs
},
"inputs": {**new_kwargs},
"outputs": response_obj.json(),
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
},
headers={
"x-api-key": self.langsmith_api_key
}
headers={"x-api-key": self.langsmith_api_key},
)
print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}"
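For context, the project_name / run_name metadata read above is supplied on the call itself; a hedged usage sketch, assuming "langsmith" is a registered success-callback name (values illustrative).

import os, litellm

os.environ["LANGSMITH_API_KEY"] = "..."  # read by __init__ above
litellm.success_callback = ["langsmith"]
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"project_name": "litellm-completion", "run_name": "LLMRun"},
)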

View file

@ -1,6 +1,7 @@
import requests, traceback, json, os
import types
class LiteDebugger:
user_email = None
dashboard_url = None
@ -12,9 +13,15 @@ class LiteDebugger:
def validate_environment(self, email):
try:
self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL"))
if self.user_email == None: # if users are trying to use_client=True but token not set
raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token")
self.user_email = (
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
)
if (
self.user_email == None
): # if users are trying to use_client=True but token not set
raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
)
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try:
print(
@ -42,7 +49,9 @@ class LiteDebugger:
litellm_params,
optional_params,
):
print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}")
print_verbose(
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
@ -56,7 +65,11 @@ class LiteDebugger:
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding":
for message in messages: # assuming the input is a list as required by the embedding function
for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = {
"model": model,
"messages": [{"role": "user", "content": message}],
@ -79,7 +92,9 @@ class LiteDebugger:
elif call_type == "completion":
litellm_data_obj = {
"model": model,
"messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}],
"messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
@ -95,20 +110,30 @@ class LiteDebugger:
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: completion api response - {response.text}")
print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream):
print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}")
def post_call_log_event(
self, original_response, litellm_call_id, print_verbose, call_type, stream
):
print_verbose(
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
)
try:
if call_type == "embedding":
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector
"additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
@ -122,7 +147,11 @@ class LiteDebugger:
elif call_type == "completion" and stream:
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response},
"additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
@ -147,9 +176,11 @@ class LiteDebugger:
litellm_call_id,
print_verbose,
call_type,
stream = False
stream=False,
):
print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}")
print_verbose(
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"

View file

@ -18,19 +18,17 @@ class PromptLayerLogger:
# Method definition
try:
new_kwargs = {}
new_kwargs['model'] = kwargs['model']
new_kwargs['messages'] = kwargs['messages']
new_kwargs["model"] = kwargs["model"]
new_kwargs["messages"] = kwargs["messages"]
# add kwargs["optional_params"] to new_kwargs
for optional_param in kwargs["optional_params"]:
new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
print_verbose(
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
)
request_response = requests.post(
"https://api.promptlayer.com/rest/track-request",
json={
@ -62,10 +60,12 @@ class PromptLayerLogger:
json={
"request_id": response_json["request_id"],
"api_key": self.key,
"metadata": kwargs["litellm_params"]["metadata"]
"metadata": kwargs["litellm_params"]["metadata"],
},
)
print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}")
print_verbose(
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
)
except:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
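A similar hedged sketch for this logger; the env-var name and callback string are assumptions, only the metadata forwarding to /rest/track-metadata is taken from the hunk above.

import os, litellm

os.environ["PROMPTLAYER_API_KEY"] = "..."  # assumed key name
litellm.success_callback = ["promptlayer"]
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"environment": "dev"},  # forwarded as kwargs["litellm_params"]["metadata"]
)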

View file

@ -9,6 +9,7 @@ import traceback
import datetime, subprocess, sys
import litellm
class Supabase:
# Class variables or attributes
supabase_table_name = "request_logs"

View file

@ -2,6 +2,7 @@ class TraceloopLogger:
def __init__(self):
from traceloop.sdk.tracing.tracing import TracerWrapper
from traceloop.sdk import Traceloop
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
self.tracer_wrapper = TracerWrapper()
@ -29,15 +30,18 @@ class TraceloopLogger:
)
if "stop" in optional_params:
span.set_attribute(
SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop")
SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
optional_params.get("stop"),
)
if "frequency_penalty" in optional_params:
span.set_attribute(
SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty")
SpanAttributes.LLM_FREQUENCY_PENALTY,
optional_params.get("frequency_penalty"),
)
if "presence_penalty" in optional_params:
span.set_attribute(
SpanAttributes.LLM_PRESENCE_PENALTY, optional_params.get("presence_penalty")
SpanAttributes.LLM_PRESENCE_PENALTY,
optional_params.get("presence_penalty"),
)
if "top_p" in optional_params:
span.set_attribute(
@ -45,7 +49,10 @@ class TraceloopLogger:
)
if "tools" in optional_params or "functions" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions"))
SpanAttributes.LLM_REQUEST_FUNCTIONS,
optional_params.get(
"tools", optional_params.get("functions")
),
)
if "user" in optional_params:
span.set_attribute(
@ -53,7 +60,8 @@ class TraceloopLogger:
)
if "max_tokens" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens")
SpanAttributes.LLM_REQUEST_MAX_TOKENS,
kwargs.get("max_tokens"),
)
if "temperature" in optional_params:
span.set_attribute(

View file

@ -1,4 +1,4 @@
imported_openAIResponse=True
imported_openAIResponse = True
try:
import io
import logging
@ -12,14 +12,11 @@ try:
else:
from typing_extensions import Literal, Protocol
logger = logging.getLogger(__name__)
K = TypeVar("K", bound=str)
V = TypeVar("V")
class OpenAIResponse(Protocol[K, V]): # type: ignore
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
@ -30,7 +27,6 @@ try:
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(
self,
@ -44,7 +40,9 @@ try:
elif response["object"] == "text_completion":
return self._resolve_completion(request, response, time_elapsed)
elif response["object"] == "chat.completion":
return self._resolve_chat_completion(request, response, time_elapsed)
return self._resolve_chat_completion(
request, response, time_elapsed
)
else:
logger.info(f"Unknown OpenAI response object: {response['object']}")
except Exception as e:
@ -113,7 +111,8 @@ try:
"""Resolves the request and response objects for `openai.Completion`."""
request_str = f"\n\n**Prompt**: {request['prompt']}\n"
choices = [
f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
f"\n\n**Completion**: {choice['text']}\n"
for choice in response["choices"]
]
return self._request_response_result_to_trace(
@ -167,9 +166,9 @@ try:
]
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace
except:
imported_openAIResponse=False
except:
imported_openAIResponse = False
#### What this does ####
@ -182,15 +181,20 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class WeightsBiasesLogger:
# Class variables or attributes
def __init__(self):
try:
import wandb
except:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
if imported_openAIResponse==False:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
if imported_openAIResponse == False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
self.resolver = OpenAIRequestResponseResolver()
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
@ -198,13 +202,13 @@ class WeightsBiasesLogger:
import wandb
try:
print_verbose(
f"W&B Logging - Enters logging function for model {kwargs}"
)
print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
run = wandb.init()
print_verbose(response_obj)
trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds())
trace = self.resolver(
kwargs, response_obj, (end_time - start_time).total_seconds()
)
if trace is not None:
run.log({"trace": trace})
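Hedged usage sketch; "wandb" as the callback string is an assumption based on this logger's name.

import litellm

litellm.success_callback = ["wandb"]
litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
# the resolver above then turns the response into a W&B trace tree and run.log()s it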

View file

@ -7,17 +7,21 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message
import litellm
class AI21Error(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
self.request = httpx.Request(
method="POST", url="https://api.ai21.com/studio/v1/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AI21Config():
class AI21Config:
"""
Reference: https://docs.ai21.com/reference/j2-complete-ref
@ -43,40 +47,53 @@ class AI21Config():
- `countPenalty` (object): Placeholder for count penalty object.
"""
numResults: Optional[int]=None
maxTokens: Optional[int]=None
minTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
topKReturn: Optional[int]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self,
numResults: Optional[int]=None,
maxTokens: Optional[int]=None,
minTokens: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[float]=None,
stopSequences: Optional[list]=None,
topKReturn: Optional[int]=None,
frequencePenalty: Optional[dict]=None,
presencePenalty: Optional[dict]=None,
countPenalty: Optional[dict]=None) -> None:
numResults: Optional[int] = None
maxTokens: Optional[int] = None
minTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
topKReturn: Optional[int] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
numResults: Optional[int] = None,
maxTokens: Optional[int] = None,
minTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
topKReturn: Optional[int] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
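A hedged sketch of how this config pattern is used; because __init__ writes onto the class with setattr, values set here act as process-wide defaults, and the merge loop in completion() below only applies them when the caller did not pass the key.

import litellm

litellm.AI21Config(maxTokens=256, temperature=0.2)  # class-level defaults
print(litellm.AI21Config.get_config())  # roughly {'maxTokens': 256, 'temperature': 0.2}
# a per-call value such as completion(..., max_tokens=100) still wins, since the
# corresponding key is already present in optional_params when the merge loop runs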
def validate_environment(api_key):
@ -91,6 +108,7 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
@ -110,20 +128,18 @@ def completion(
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
## Load Config
config = litellm.AI21Config.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -143,10 +159,7 @@ def completion(
api_base + model + "/complete", headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AI21Error(
status_code=response.status_code,
message=response.text
)
raise AI21Error(status_code=response.status_code, message=response.text)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
@ -166,16 +179,20 @@ def completion(
message_obj = Message(content=item["data"]["text"])
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finishReason"]["reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)
raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)
    ## CALCULATING USAGE - count prompt/completion tokens with the local encoding
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
@ -189,6 +206,7 @@ def completion(
}
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -8,17 +8,21 @@ import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx
class AlephAlphaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete")
self.request = httpx.Request(
method="POST", url="https://api.aleph-alpha.com/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AlephAlphaConfig():
class AlephAlphaConfig:
"""
Reference: https://docs.aleph-alpha.com/api/complete/
@ -72,83 +76,97 @@ class AlephAlphaConfig():
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
"""
maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int]=None
echo: Optional[bool]=None
temperature: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
presence_penalty: Optional[int]=None
frequency_penalty: Optional[int]=None
sequence_penalty: Optional[int]=None
sequence_penalty_min_length: Optional[int]=None
repetition_penalties_include_prompt: Optional[bool]=None
repetition_penalties_include_completion: Optional[bool]=None
use_multiplicative_presence_penalty: Optional[bool]=None
use_multiplicative_frequency_penalty: Optional[bool]=None
use_multiplicative_sequence_penalty: Optional[bool]=None
penalty_bias: Optional[str]=None
penalty_exceptions_include_stop_sequences: Optional[bool]=None
best_of: Optional[int]=None
n: Optional[int]=None
logit_bias: Optional[dict]=None
log_probs: Optional[int]=None
stop_sequences: Optional[list]=None
tokens: Optional[bool]=None
raw_completion: Optional[bool]=None
disable_optimizations: Optional[bool]=None
completion_bias_inclusion: Optional[list]=None
completion_bias_exclusion: Optional[list]=None
completion_bias_inclusion_first_token_only: Optional[bool]=None
completion_bias_exclusion_first_token_only: Optional[bool]=None
contextual_control_threshold: Optional[int]=None
control_log_additive: Optional[bool]=None
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
presence_penalty: Optional[int] = None
frequency_penalty: Optional[int] = None
sequence_penalty: Optional[int] = None
sequence_penalty_min_length: Optional[int] = None
repetition_penalties_include_prompt: Optional[bool] = None
repetition_penalties_include_completion: Optional[bool] = None
use_multiplicative_presence_penalty: Optional[bool] = None
use_multiplicative_frequency_penalty: Optional[bool] = None
use_multiplicative_sequence_penalty: Optional[bool] = None
penalty_bias: Optional[str] = None
penalty_exceptions_include_stop_sequences: Optional[bool] = None
best_of: Optional[int] = None
n: Optional[int] = None
logit_bias: Optional[dict] = None
log_probs: Optional[int] = None
stop_sequences: Optional[list] = None
tokens: Optional[bool] = None
raw_completion: Optional[bool] = None
disable_optimizations: Optional[bool] = None
completion_bias_inclusion: Optional[list] = None
completion_bias_exclusion: Optional[list] = None
completion_bias_inclusion_first_token_only: Optional[bool] = None
completion_bias_exclusion_first_token_only: Optional[bool] = None
contextual_control_threshold: Optional[int] = None
control_log_additive: Optional[bool] = None
def __init__(self,
maximum_tokens: Optional[int]=None,
minimum_tokens: Optional[int]=None,
echo: Optional[bool]=None,
temperature: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[int]=None,
presence_penalty: Optional[int]=None,
frequency_penalty: Optional[int]=None,
sequence_penalty: Optional[int]=None,
sequence_penalty_min_length: Optional[int]=None,
repetition_penalties_include_prompt: Optional[bool]=None,
repetition_penalties_include_completion: Optional[bool]=None,
use_multiplicative_presence_penalty: Optional[bool]=None,
use_multiplicative_frequency_penalty: Optional[bool]=None,
use_multiplicative_sequence_penalty: Optional[bool]=None,
penalty_bias: Optional[str]=None,
penalty_exceptions_include_stop_sequences: Optional[bool]=None,
best_of: Optional[int]=None,
n: Optional[int]=None,
logit_bias: Optional[dict]=None,
log_probs: Optional[int]=None,
stop_sequences: Optional[list]=None,
tokens: Optional[bool]=None,
raw_completion: Optional[bool]=None,
disable_optimizations: Optional[bool]=None,
completion_bias_inclusion: Optional[list]=None,
completion_bias_exclusion: Optional[list]=None,
completion_bias_inclusion_first_token_only: Optional[bool]=None,
completion_bias_exclusion_first_token_only: Optional[bool]=None,
contextual_control_threshold: Optional[int]=None,
control_log_additive: Optional[bool]=None) -> None:
def __init__(
self,
maximum_tokens: Optional[int] = None,
minimum_tokens: Optional[int] = None,
echo: Optional[bool] = None,
temperature: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
presence_penalty: Optional[int] = None,
frequency_penalty: Optional[int] = None,
sequence_penalty: Optional[int] = None,
sequence_penalty_min_length: Optional[int] = None,
repetition_penalties_include_prompt: Optional[bool] = None,
repetition_penalties_include_completion: Optional[bool] = None,
use_multiplicative_presence_penalty: Optional[bool] = None,
use_multiplicative_frequency_penalty: Optional[bool] = None,
use_multiplicative_sequence_penalty: Optional[bool] = None,
penalty_bias: Optional[str] = None,
penalty_exceptions_include_stop_sequences: Optional[bool] = None,
best_of: Optional[int] = None,
n: Optional[int] = None,
logit_bias: Optional[dict] = None,
log_probs: Optional[int] = None,
stop_sequences: Optional[list] = None,
tokens: Optional[bool] = None,
raw_completion: Optional[bool] = None,
disable_optimizations: Optional[bool] = None,
completion_bias_inclusion: Optional[list] = None,
completion_bias_exclusion: Optional[list] = None,
completion_bias_inclusion_first_token_only: Optional[bool] = None,
completion_bias_exclusion_first_token_only: Optional[bool] = None,
contextual_control_threshold: Optional[int] = None,
control_log_additive: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -160,6 +178,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -179,7 +198,9 @@ def completion(
## Load Config
config = litellm.AlephAlphaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url = api_base
@ -188,21 +209,17 @@ def completion(
if "control" in model: # follow the ###Instruction / ###Response format
for idx, message in enumerate(messages):
if "role" in message:
if idx == 0: # set first message as instruction (required), let later user messages be input
if (
idx == 0
): # set first message as instruction (required), let later user messages be input
prompt += f"###Instruction: {message['content']}"
else:
if message["role"] == "system":
prompt += (
f"###Instruction: {message['content']}"
)
prompt += f"###Instruction: {message['content']}"
elif message["role"] == "user":
prompt += (
f"###Input: {message['content']}"
)
prompt += f"###Input: {message['content']}"
else:
prompt += (
f"###Response: {message['content']}"
)
prompt += f"###Response: {message['content']}"
else:
prompt += f"{message['content']}"
else:
@ -221,7 +238,10 @@ def completion(
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
@ -249,16 +269,21 @@ def completion(
message_obj = Message(content=item["completion"])
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except:
raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code)
raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
    ## CALCULATING USAGE - count prompt/completion tokens with the local encoding
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
@ -268,11 +293,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
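A worked example of the ###Instruction / ###Input / ###Response prompt assembly used above for "control" models (message contents invented).

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Say hi."},
]
# the loop above renders this roughly as:
#   "###Instruction: You are terse.###Input: Say hi."
# the first message always becomes the Instruction; later system / user / other
# messages map to ###Instruction / ###Input / ###Response respectively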

View file

@ -9,52 +9,72 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
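For reference, a hedged sketch of the prompt shape these constants produce for the legacy /v1/complete endpoint; the exact output of prompt_factory may differ.

# [{"role": "user", "content": "Hello"}] renders approximately as:
prompt = (
    AnthropicConstants.HUMAN_PROMPT.value  # "\n\nHuman: "
    + "Hello"
    + AnthropicConstants.AI_PROMPT.value  # "\n\nAssistant: "
)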
class AnthropicError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete")
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AnthropicConfig():
class AnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
metadata: Optional[dict]=None
def __init__(self,
max_tokens_to_sample: Optional[int]=256, # anthropic requires a default
stop_sequences: Optional[list]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,
top_k: Optional[int]=None,
metadata: Optional[dict]=None) -> None:
max_tokens_to_sample: Optional[
int
] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# makes headers for API call
@ -71,6 +91,7 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
@ -93,15 +114,19 @@ def completion(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -127,15 +152,17 @@ def completion(
)
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
@ -159,9 +186,9 @@ def completion(
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response[
"completion"
]
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
@ -177,11 +204,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -1,7 +1,13 @@
from typing import Optional, Union, Any
import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
)
from typing import Callable, Optional
from litellm import OpenAIConfig
import litellm, json
@ -9,8 +15,15 @@ import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request:
@ -20,11 +33,14 @@ class AzureOpenAIError(Exception):
if response:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AzureOpenAIConfig(OpenAIConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -52,18 +68,21 @@ class AzureOpenAIConfig(OpenAIConfig):
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
def __init__(self,
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]]= None,
functions: Optional[list]= None,
logit_bias: Optional[dict]= None,
max_tokens: Optional[int]= None,
n: Optional[int]= None,
presence_penalty: Optional[int]= None,
stop: Optional[Union[str,list]]=None,
temperature: Optional[int]= None,
top_p: Optional[int]= None) -> None:
super().__init__(frequency_penalty,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty,
function_call,
functions,
logit_bias,
@ -72,10 +91,11 @@ class AzureOpenAIConfig(OpenAIConfig):
presence_penalty,
stop,
temperature,
top_p)
top_p,
)
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@ -89,7 +109,8 @@ class AzureChatCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def completion(self,
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
@ -105,15 +126,16 @@ class AzureChatCompletion(BaseLLM):
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict]=None,
client = None,
headers: Optional[dict] = None,
client=None,
):
super().completion()
exception_mapping_worked = False
try:
if model is None or messages is None:
raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2)
@ -131,7 +153,7 @@ class AzureChatCompletion(BaseLLM):
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -143,25 +165,52 @@ class AzureChatCompletion(BaseLLM):
else:
client = AzureOpenAI(**azure_client_params)
data = {
"model": None,
"messages": messages,
**optional_params
}
data = {"model": None, "messages": messages, **optional_params}
else:
data = {
"model": model, # type: ignore
"messages": messages,
**optional_params
**optional_params,
}
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj)
return self.acompletion(
api_base=api_base,
data=data,
model_response=model_response,
api_key=api_key,
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
## LOGGING
logging_obj.pre_call(
@ -170,7 +219,7 @@ class AzureChatCompletion(BaseLLM):
additional_args={
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token
"azure_ad_token": azure_ad_token,
},
"api_version": api_version,
"api_base": api_base,
@ -178,7 +227,9 @@ class AzureChatCompletion(BaseLLM):
},
)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -209,14 +260,18 @@ class AzureChatCompletion(BaseLLM):
"api_base": api_base,
},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
raise e
async def acompletion(self,
async def acompletion(
self,
api_key: str,
api_version: str,
model: str,
@ -224,15 +279,17 @@ class AzureChatCompletion(BaseLLM):
data: dict,
timeout: Any,
model_response: ModelResponse,
azure_ad_token: Optional[str]=None,
client = None, # this is the AsyncAzureOpenAI
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -252,12 +309,20 @@ class AzureChatCompletion(BaseLLM):
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
return convert_to_model_response_object(
response_object=json.loads(response.model_dump_json()),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
@ -267,7 +332,8 @@ class AzureChatCompletion(BaseLLM):
else:
raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self,
def streaming(
self,
logging_obj,
api_base: str,
api_key: str,
@ -275,12 +341,14 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None,
azure_ad_token: Optional[str] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -300,15 +368,26 @@ class AzureChatCompletion(BaseLLM):
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper
async def async_streaming(self,
async def async_streaming(
self,
logging_obj,
api_base: str,
api_key: str,
@ -316,8 +395,8 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None,
client = None,
azure_ad_token: Optional[str] = None,
client=None,
):
# init AzureOpenAI Client
azure_client_params = {
@ -326,7 +405,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -338,12 +417,22 @@ class AzureChatCompletion(BaseLLM):
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
@ -355,7 +444,7 @@ class AzureChatCompletion(BaseLLM):
api_key: str,
input: list,
client=None,
logging_obj=None
logging_obj=None,
):
response = None
try:
@ -372,7 +461,11 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="embedding",
)
except Exception as e:
## LOGGING
logging_obj.post_call(
@ -383,7 +476,8 @@ class AzureChatCompletion(BaseLLM):
)
raise e
def embedding(self,
def embedding(
self,
model: str,
input: list,
api_key: str,
@ -393,8 +487,8 @@ class AzureChatCompletion(BaseLLM):
logging_obj=None,
model_response=None,
optional_params=None,
azure_ad_token: Optional[str]=None,
client = None,
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
):
super().embedding()
@ -402,14 +496,12 @@ class AzureChatCompletion(BaseLLM):
if self._client_session is None:
self._client_session = self.create_client_session()
try:
data = {
"model": model,
"input": input,
**optional_params
}
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -431,15 +523,19 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token
}
"headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
api_key=api_key,
model_response=model_response,
azure_client_params=azure_client_params,
)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -455,7 +551,6 @@ class AzureChatCompletion(BaseLLM):
original_response=response,
)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
@ -465,6 +560,7 @@ class AzureChatCompletion(BaseLLM):
raise e
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
@ -475,13 +571,17 @@ class AzureChatCompletion(BaseLLM):
api_key: str,
input: list,
client=None,
logging_obj=None
logging_obj=None,
):
response = None
try:
if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
)
openai_aclient = AsyncAzureOpenAI(
http_client=client_session, **azure_client_params
)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data)
@ -493,7 +593,11 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="image_generation",
)
except Exception as e:
## LOGGING
logging_obj.post_call(
@ -504,15 +608,16 @@ class AzureChatCompletion(BaseLLM):
)
raise e
def image_generation(self,
def image_generation(
self,
prompt: str,
timeout: float,
model: Optional[str]=None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str]=None,
azure_ad_token: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
@ -524,14 +629,12 @@ class AzureChatCompletion(BaseLLM):
model = model
else:
model = None
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
@ -539,7 +642,7 @@ class AzureChatCompletion(BaseLLM):
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@ -551,7 +654,9 @@ class AzureChatCompletion(BaseLLM):
return response
if client is None:
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(),
)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else:
azure_client = client
@ -560,7 +665,12 @@ class AzureChatCompletion(BaseLLM):
logging_obj.pre_call(
input=prompt,
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
)
## COMPLETION CALL
@ -582,4 +692,5 @@ class AzureChatCompletion(BaseLLM):
raise e
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
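The methods above rebuild the same azure_client_params dict repeatedly; a condensed, hedged sketch of the pattern with illustrative values (the endpoint, deployment and version strings are assumptions).

from openai import AzureOpenAI

azure_client_params = {
    "api_version": "2023-07-01-preview",
    "azure_endpoint": "https://my-resource.openai.azure.com",
    "azure_deployment": "gpt-35-turbo",
    "max_retries": 2,
    "timeout": 600,
    "api_key": "...",  # or "azure_ad_token" instead, as in the branches above
}
client = AzureOpenAI(**azure_client_params)
resp = client.chat.completions.create(
    model="gpt-35-turbo", messages=[{"role": "user", "content": "hi"}]
)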

View file

@ -3,8 +3,10 @@ import litellm
import httpx, certifi, ssl
from typing import Optional
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def create_client_session(self):
if litellm.client_session:
_client_session = litellm.client_session
@ -22,26 +24,22 @@ class BaseLLM:
return _aclient_session
def __exit__(self):
if hasattr(self, '_client_session'):
if hasattr(self, "_client_session"):
self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'):
if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose()
def validate_environment(self): # set up the environment required to run the model
pass
def completion(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model completion calls
pass
def embedding(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -6,6 +6,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse, Usage
class BasetenError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -14,6 +15,7 @@ class BasetenError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
def validate_environment(api_key):
headers = {
"accept": "application/json",
@ -23,6 +25,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Api-Key {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -52,7 +55,9 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
## LOGGING
@ -66,9 +71,13 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True if "stream" in optional_params and optional_params["stream"] == True else False
stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
)
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
):
return response.iter_lines()
else:
## LOGGING
@ -91,9 +100,7 @@ def completion(
if (
isinstance(completion_response["model_output"], dict)
and "data" in completion_response["model_output"]
and isinstance(
completion_response["model_output"]["data"], list
)
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
@ -112,12 +119,19 @@ def completion(
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
status_code=response.status_code,
)
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
@ -125,7 +139,7 @@ def completion(
else:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@ -139,11 +153,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
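
For orientation, a hedged usage sketch of the Baseten route whose handler is diffed above; the model-version id and the env-var name are placeholders drawn from litellm's provider docs, not from this diff.

# Illustrative only: calling a Baseten-hosted model through litellm.
import os
import litellm

os.environ["BASETEN_API_KEY"] = ""  # placeholder key

response = litellm.completion(
    model="baseten/qvv0xeq",  # example Baseten model-version id
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response["choices"][0]["message"]["content"])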


@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
class BedrockError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock")
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AmazonTitanConfig():
class AmazonTitanConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@ -29,29 +33,44 @@ class AmazonTitanConfig():
- `temperature` (float) temperature for model,
- `topP` (int) top p for model
"""
maxTokenCount: Optional[int]=None
stopSequences: Optional[list]=None
temperature: Optional[float]=None
topP: Optional[int]=None
def __init__(self,
maxTokenCount: Optional[int]=None,
stopSequences: Optional[list]=None,
temperature: Optional[float]=None,
topP: Optional[int]=None) -> None:
maxTokenCount: Optional[int] = None
stopSequences: Optional[list] = None
temperature: Optional[float] = None
topP: Optional[int] = None
def __init__(
self,
maxTokenCount: Optional[int] = None,
stopSequences: Optional[list] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
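
As a quick illustration of the config-class pattern above (a minimal sketch, assuming only the behaviour visible in __init__ and get_config()): instantiating the class writes its non-None arguments onto the class itself, so they act as process-wide defaults, and get_config() then returns exactly those values.

# Minimal sketch of the AmazonTitanConfig behaviour shown above.
import litellm

litellm.AmazonTitanConfig(maxTokenCount=256, temperature=0.2)  # sets class-level defaults
print(litellm.AmazonTitanConfig.get_config())
# expected: {'maxTokenCount': 256, 'temperature': 0.2}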
class AmazonAnthropicConfig():
class AmazonAnthropicConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
"""
max_tokens_to_sample: Optional[int]=litellm.max_tokens
stop_sequences: Optional[list]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
anthropic_version: Optional[str]=None
def __init__(self,
max_tokens_to_sample: Optional[int]=None,
stop_sequences: Optional[list]=None,
temperature: Optional[float]=None,
top_k: Optional[int]=None,
top_p: Optional[int]=None,
anthropic_version: Optional[str]=None) -> None:
max_tokens_to_sample: Optional[int] = litellm.max_tokens
stop_sequences: Optional[list] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
anthropic_version: Optional[str] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = None,
stop_sequences: Optional[list] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
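
The "completion(top_k=3) > anthropic_config(top_k=3)" comments in completion() further down describe how these defaults are applied; a small self-contained sketch of that merge, with made-up values, follows.

# Per-call parameters win; config values only fill the gaps (values are illustrative).
config = {"temperature": 0.5, "top_k": 40}   # e.g. litellm.AmazonAnthropicConfig.get_config()
inference_params = {"temperature": 0.9}      # what the caller passed to completion()

for k, v in config.items():
    if k not in inference_params:
        inference_params[k] = v

print(inference_params)  # {'temperature': 0.9, 'top_k': 40}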
class AmazonCohereConfig():
class AmazonCohereConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@ -100,27 +134,42 @@ class AmazonCohereConfig():
- `temperature` (float) model temperature,
- `return_likelihood` (string) n/a
"""
max_tokens: Optional[int]=None
temperature: Optional[float]=None
return_likelihood: Optional[str]=None
def __init__(self,
max_tokens: Optional[int]=None,
temperature: Optional[float]=None,
return_likelihood: Optional[str]=None) -> None:
max_tokens: Optional[int] = None
temperature: Optional[float] = None
return_likelihood: Optional[str] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAI21Config():
class AmazonAI21Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@ -140,39 +189,55 @@ class AmazonAI21Config():
- `countPenalty` (object): Placeholder for count penalty object.
"""
maxTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self,
maxTokens: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[float]=None,
stopSequences: Optional[list]=None,
frequencePenalty: Optional[dict]=None,
presencePenalty: Optional[dict]=None,
countPenalty: Optional[dict]=None) -> None:
maxTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AmazonLlamaConfig():
class AmazonLlamaConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
@ -182,48 +247,72 @@ class AmazonLlamaConfig():
- `temperature` (float) temperature for model,
- `top_p` (float) top p for model
"""
max_gen_len: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
def __init__(self,
maxTokenCount: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[int]=None) -> None:
max_gen_len: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
def __init__(
self,
maxTokenCount: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def init_bedrock_client(
region_name = None,
region_name=None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] =None,
aws_bedrock_runtime_endpoint: Optional[str]=None,
):
aws_region_name: Optional[str] = None,
aws_bedrock_runtime_endpoint: Optional[str] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
params_to_check = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith('os.environ/'):
if param and param.startswith("os.environ/"):
params_to_check[i] = get_secret(param)
# Assign updated values back to parameters
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
) = params_to_check
if region_name:
pass
elif aws_region_name:
@ -233,7 +322,10 @@ def init_bedrock_client(
elif standard_aws_region_name:
region_name = standard_aws_region_name
else:
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
raise BedrockError(
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
status_code=401,
)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@ -242,9 +334,10 @@ def init_bedrock_client(
elif env_aws_bedrock_runtime_endpoint:
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com'
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
import boto3
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
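
A hedged sketch of the credential handling above: auth parameters can be supplied per call, and values prefixed with "os.environ/" are resolved through get_secret(); the model id and key names below are examples, not requirements of this code.

# Illustrative only: per-call Bedrock credentials, including the "os.environ/" indirection.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-v2",                   # example Bedrock model id
    messages=[{"role": "user", "content": "Hello"}],
    aws_region_name="us-west-2",
    aws_access_key_id="os.environ/AWS_ACCESS_KEY_ID",      # resolved via get_secret() above
    aws_secret_access_key="os.environ/AWS_SECRET_ACCESS_KEY",
)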
@ -279,22 +372,20 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt
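
For reference, a hedged sketch of what a custom_prompt_dict entry might look like, inferred from the keys read above (roles, initial_prompt_value, final_prompt_value); the inner role-dict fields follow litellm's prompt factory and are assumptions here, not something this diff defines.

# Assumed shape of a custom prompt registration for a Bedrock Anthropic model.
custom_prompt_dict = {
    "anthropic.claude-v2": {
        "roles": {
            "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "\n\nAssistant: ",
    }
}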
@ -309,6 +400,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion(
model: str,
messages: list,
@ -327,7 +419,9 @@ def completion(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop(
@ -343,67 +437,71 @@ def completion(
model = model
provider = model.split(".")[0]
prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", False)
if provider == "anthropic":
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "cohere":
## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if optional_params.get("stream", False) == True:
inference_params["stream"] = True # cohere requires stream = True in inference params
data = json.dumps({
"prompt": prompt,
**inference_params
})
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "meta":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "amazon": # amazon titan
## LOAD CONFIG
config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
data = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": inference_params,
})
}
)
## COMPLETION CALL
accept = 'application/json'
contentType = 'application/json'
accept = "application/json"
contentType = "application/json"
if stream == True:
if provider == "ai21":
## LOGGING
@ -418,17 +516,17 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = client.invoke_model(
body=data,
modelId=model,
accept=accept,
contentType=contentType
body=data, modelId=model, accept=accept, contentType=contentType
)
response = response.get('body').read()
response = response.get("body").read()
return response
else:
## LOGGING
@ -443,16 +541,16 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = client.invoke_model_with_response_stream(
body=data,
modelId=model,
accept=accept,
contentType=contentType
body=data, modelId=model, accept=accept, contentType=contentType
)
response = response.get('body')
response = response.get("body")
return response
try:
## LOGGING
@ -467,18 +565,18 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = client.invoke_model(
body=data,
modelId=model,
accept=accept,
contentType=contentType
body=data, modelId=model, accept=accept, contentType=contentType
)
except Exception as e:
raise BedrockError(status_code=500, message=str(e))
response_body = json.loads(response.get('body').read())
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
@ -491,16 +589,16 @@ def completion(
## RESPONSE OBJECT
outputText = "default"
if provider == "ai21":
outputText = response_body.get('completions')[0].get('data').get('text')
outputText = response_body.get("completions")[0].get("data").get("text")
elif provider == "anthropic":
outputText = response_body['completion']
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
elif provider == "cohere":
outputText = response_body["generations"][0]["text"]
elif provider == "meta":
outputText = response_body["generation"]
else: # amazon titan
outputText = response_body.get('results')[0].get('outputText')
outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400:
@ -513,12 +611,13 @@ def completion(
if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText
except:
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
raise BedrockError(
message=json.dumps(outputText),
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -528,7 +627,7 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens = prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
@ -540,8 +639,10 @@ def completion(
raise e
else:
import traceback
raise BedrockError(status_code=500, message=traceback.format_exc())
def _embedding_func_single(
model: str,
input: str,
@ -554,13 +655,17 @@ def _embedding_func_single(
## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop("user", None) # make sure user is not passed in for bedrock call
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
if provider == "amazon":
input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params}
# data = json.dumps(data)
elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
inference_params["input_type"] = inference_params.get(
"input_type", "search_document"
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8")
## LOGGING
@ -574,8 +679,10 @@ def _embedding_func_single(
logging_obj.pre_call(
input=input,
api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model,
"texts": input}, "request_str": request_str},
additional_args={
"complete_input_dict": {"model": model, "texts": input},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
@ -600,7 +707,10 @@ def _embedding_func_single(
elif provider == "amazon":
return response_body.get("embedding")
except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
def embedding(
model: str,
@ -616,7 +726,9 @@ def embedding(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -627,8 +739,16 @@ def embedding(
)
## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
embeddings = [
_embedding_func_single(
model,
i,
optional_params=optional_params,
client=client,
logging_obj=logging_obj,
)
for i in input
] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary
embedding_response = []
@ -647,12 +767,10 @@ def embedding(
input_str = "".join(input)
input_tokens+=len(encoding.encode(input_str))
input_tokens += len(encoding.encode(input_str))
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens + 0
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
)
model_response.usage = usage
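
A hedged usage sketch of the embedding path above, which calls _embedding_func_single() once per input string (sequentially, per the TODO); the model id is an example.

# Illustrative only: Bedrock embeddings through litellm, one invoke_model call per input.
import litellm

response = litellm.embedding(
    model="bedrock/cohere.embed-english-v3",   # example Bedrock embedding model
    input=["hello world", "good morning"],
)
print(len(response["data"]), "embeddings returned")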


@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx
class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate")
self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CohereConfig():
class CohereConfig:
"""
Reference: https://docs.cohere.com/reference/generate
@ -50,46 +54,60 @@ class CohereConfig():
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
"""
num_generations: Optional[int]=None
max_tokens: Optional[int]=None
truncate: Optional[str]=None
temperature: Optional[int]=None
preset: Optional[str]=None
end_sequences: Optional[list]=None
stop_sequences: Optional[list]=None
k: Optional[int]=None
p: Optional[int]=None
frequency_penalty: Optional[int]=None
presence_penalty: Optional[int]=None
return_likelihoods: Optional[str]=None
logit_bias: Optional[dict]=None
def __init__(self,
num_generations: Optional[int]=None,
max_tokens: Optional[int]=None,
truncate: Optional[str]=None,
temperature: Optional[int]=None,
preset: Optional[str]=None,
end_sequences: Optional[list]=None,
stop_sequences: Optional[list]=None,
k: Optional[int]=None,
p: Optional[int]=None,
frequency_penalty: Optional[int]=None,
presence_penalty: Optional[int]=None,
return_likelihoods: Optional[str]=None,
logit_bias: Optional[dict]=None) -> None:
num_generations: Optional[int] = None
max_tokens: Optional[int] = None
truncate: Optional[str] = None
temperature: Optional[int] = None
preset: Optional[str] = None
end_sequences: Optional[list] = None
stop_sequences: Optional[list] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
return_likelihoods: Optional[str] = None
logit_bias: Optional[dict] = None
def __init__(
self,
num_generations: Optional[int] = None,
max_tokens: Optional[int] = None,
truncate: Optional[str] = None,
temperature: Optional[int] = None,
preset: Optional[str] = None,
end_sequences: Optional[list] = None,
stop_sequences: Optional[list] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
return_likelihoods: Optional[str] = None,
logit_bias: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
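
Before the request-building code below, a hedged usage sketch for the Cohere text-generation route; the model name is an example and the extra parameters are two of those documented in CohereConfig above.

# Illustrative only: Cohere completion via litellm with CohereConfig-style parameters.
import os
import litellm

os.environ["COHERE_API_KEY"] = ""  # placeholder key

response = litellm.completion(
    model="command-nightly",   # example Cohere model
    messages=[{"role": "user", "content": "Write a haiku about the sea"}],
    max_tokens=50,
    temperature=0.3,
)
print(response["choices"][0]["message"]["content"])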
def validate_environment(api_key):
headers = {
@ -100,6 +118,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -119,9 +138,11 @@ def completion(
prompt = " ".join(message["content"] for message in messages)
## Load Config
config=litellm.CohereConfig.get_config()
config = litellm.CohereConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -134,14 +155,21 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
## error handling for cohere calls
if response.status_code!=200:
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True:
@ -170,16 +198,20 @@ def completion(
message_obj = Message(content=item["text"])
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code)
raise CohereError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -189,11 +221,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(
model: str,
input: list,
@ -206,11 +239,7 @@ def embedding(
headers = validate_environment(api_key)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {
"model": model,
"texts": input,
**optional_params
}
data = {"model": model, "texts": input, **optional_params}
if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
@ -223,9 +252,7 @@ def embedding(
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
)
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
@ -244,30 +271,23 @@ def embedding(
'usage'
}
"""
if response.status_code!=200:
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
embeddings = response.json()['embeddings']
embeddings = response.json()["embeddings"]
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
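
A hedged sketch of the embedding handler above, highlighting its v3 behaviour: when the model name contains "3" and no input_type is supplied, "search_document" is filled in; the model id is an example.

# Illustrative only: Cohere v3 embeddings; input_type defaults to "search_document" above.
import litellm

response = litellm.embedding(
    model="embed-english-v3.0",   # example Cohere v3 embedding model
    input=["first document", "second document"],
)
print(response["usage"])  # prompt_tokens == total_tokens, as computed above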

View file

@ -1,9 +1,11 @@
import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"""
Async implementation of custom http transport
"""
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
@ -14,7 +16,9 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response(
status_code=400,
headers=response.headers,
@ -56,12 +65,14 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
)
return await super().handle_async_request(request)
class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
self,
request: httpx.Request,
@ -75,7 +86,9 @@ class CustomHTTPTransport(httpx.HTTPTransport):
"2023-09-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = super().handle_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport):
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response(
status_code=400,
headers=response.headers,
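
A hedged sketch of how a transport like the ones above could be attached, assuming the openai v1.x client's http_client parameter; the import path, endpoint, key, and version strings are placeholders/assumptions.

# Illustrative only: routing an Azure OpenAI client through the custom transport.
import httpx
from openai import AzureOpenAI
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport  # module path assumed

client = AzureOpenAI(
    api_key="...",                                           # placeholder
    api_version="2023-06-01-preview",
    azure_endpoint="https://my-resource.openai.azure.com",   # placeholder
    http_client=httpx.Client(transport=CustomHTTPTransport()),
)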


@ -8,17 +8,22 @@ import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class GeminiError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class GeminiConfig():
class GeminiConfig:
"""
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
@ -37,33 +42,44 @@ class GeminiConfig():
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
"""
candidate_count: Optional[int]=None
stop_sequences: Optional[list]=None
max_output_tokens: Optional[int]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
def __init__(self,
candidate_count: Optional[int]=None,
stop_sequences: Optional[list]=None,
max_output_tokens: Optional[int]=None,
temperature: Optional[float]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None) -> None:
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
max_output_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(
self,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
@ -83,10 +99,11 @@ def completion(
try:
import google.generativeai as genai
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -94,21 +111,25 @@ def completion(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -117,8 +138,11 @@ def completion(
)
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f'models/{model}')
response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params))
_model = genai.GenerativeModel(f"models/{model}")
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
except Exception as e:
raise GeminiError(
message=str(e),
@ -142,17 +166,22 @@ def completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj)
choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
raise GeminiError(message=traceback.format_exc(), status_code=response.status_code)
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
raise GeminiError(status_code=400, message=f"No response received. Original response - {response}")
raise GeminiError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE
prompt_str = ""
@ -164,9 +193,7 @@ def completion(
if content["type"] == "text":
prompt_str += content["text"]
prompt_tokens = len(
encoding.encode(prompt_str)
)
prompt_tokens = len(encoding.encode(prompt_str))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -176,11 +203,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
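
A hedged usage sketch for the Gemini route above: it assumes the google-generativeai package is installed (per the import-error message) and that the key is read from GEMINI_API_KEY; the model name is an example.

# Illustrative only: Gemini via litellm (streaming is faked in main.py, per the comment above).
import os
import litellm

os.environ["GEMINI_API_KEY"] = ""  # placeholder; env-var name assumed

response = litellm.completion(
    model="gemini/gemini-pro",   # example model name
    messages=[{"role": "user", "content": "Say hello in Portuguese"}],
)
print(response["choices"][0]["message"]["content"])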
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass


@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
self.request = httpx.Request(
method="POST", url="https://api-inference.huggingface.co/models"
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class HuggingfaceConfig():
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = False # by default don't return the input as part of the output
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
@ -46,7 +61,8 @@ class HuggingfaceConfig():
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(self,
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
@ -60,19 +76,31 @@ class HuggingfaceConfig():
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def output_parser(generated_text: str):
"""
@ -88,8 +116,11 @@ def output_parser(generated_text: str):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
tgi_models_cache = None
conv_models_cache = None
def read_tgi_conv_models():
try:
global tgi_models_cache, conv_models_cache
@ -101,9 +132,13 @@ def read_tgi_conv_models():
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the file path relative to the script's directory
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
@ -111,9 +146,13 @@ def read_tgi_conv_models():
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
@ -136,6 +175,7 @@ def get_hf_task_for_model(model):
else:
return "text-generation-inference" # default to tgi
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
"content-type": "application/json",
}
if api_key and headers is None:
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = default_headers
elif headers:
headers=headers
headers = headers
else:
headers = default_headers
return headers
def convert_to_model_response_object(self,
def convert_to_model_response_object(
self,
completion_response,
model_response,
task,
optional_params,
encoding,
input_text,
model):
model,
):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"] # type: ignore
] = completion_response[
"generated_text"
] # type: ignore
elif task == "text-generation-inference":
if (not isinstance(completion_response, list)
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]):
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
@ -221,7 +289,9 @@ class Huggingface(BaseLLM):
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
model_response._hidden_params["original_response"] = completion_response
return model_response
def completion(self,
def completion(
self,
model: str,
messages: list,
api_base: Optional[str],
@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
completion_url = f"https://api-inference.huggingface.co/models/{model}"
## Load Config
config=litellm.HuggingfaceConfig.get_config()
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
@ -300,9 +373,9 @@ class Huggingface(BaseLLM):
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses
"generated_responses": generated_responses,
},
"parameters": inference_params
"parameters": inference_params,
}
input_text = "".join(message["content"] for message in messages)
elif task == "text-generation-inference":
@ -312,16 +385,22 @@ class Huggingface(BaseLLM):
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
else:
@ -332,8 +411,12 @@ class Huggingface(BaseLLM):
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
@ -346,14 +429,22 @@ class Huggingface(BaseLLM):
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
additional_args={
"complete_input_dict": data,
"task": task,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
@ -369,29 +460,37 @@ class Huggingface(BaseLLM):
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"]
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data)
completion_url, headers=headers, data=json.dumps(data)
)
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
if (
response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
completion_response: List[Dict[str, Any]] = [
{"generated_text": content}
]
## LOGGING
logging_obj.post_call(
input=input_text,
@ -414,11 +513,16 @@ class Huggingface(BaseLLM):
completion_response = [completion_response]
except:
import traceback
raise HuggingfaceError(
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
status_code=response.status_code,
)
print_verbose(f"response: {completion_response}")
if isinstance(completion_response, dict) and "error" in completion_response:
if (
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
@ -432,7 +536,7 @@ class Huggingface(BaseLLM):
optional_params=optional_params,
encoding=encoding,
input_text=input_text,
model=model
model=model,
)
except HuggingfaceError as e:
exception_mapping_worked = True
@ -442,9 +546,11 @@ class Huggingface(BaseLLM):
raise e
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(self,
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
@ -453,54 +559,75 @@ class Huggingface(BaseLLM):
encoding: Any,
input_text: str,
model: str,
optional_params: dict):
optional_params: dict,
):
response = None
try:
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers, timeout=None)
response = await client.post(
url=api_base, json=data, headers=headers, timeout=None
)
response_json = response.json()
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
raise HuggingfaceError(
status_code=response.status_code,
message=response.text,
request=response.request,
response=response,
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json,
return self.convert_to_model_response_object(
completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params)
optional_params=optional_params,
)
except Exception as e:
if isinstance(e,httpx.TimeoutException):
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(self,
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
model: str,
):
async with httpx.AsyncClient() as client:
response = client.stream(
"POST",
url=f"{api_base}",
json=data,
headers=headers
"POST", url=f"{api_base}", json=data, headers=headers
)
async with response as r:
if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
raise HuggingfaceError(
status_code=r.status_code,
message="An error occurred while streaming",
)
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding(self,
def embedding(
self,
model: str,
input: list,
api_key: Optional[str] = None,
@ -526,29 +653,35 @@ class Huggingface(BaseLLM):
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {
"inputs": {
"source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
}
}
else:
data = {
"inputs": input # type: ignore
}
data = {"inputs": input} # type: ignore
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
)
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
@ -558,11 +691,10 @@ class Huggingface(BaseLLM):
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error'])
raise HuggingfaceError(status_code=500, message=embeddings["error"])
output_data = []
if "similarities" in embeddings:
@ -571,7 +703,7 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
else:
@ -581,7 +713,7 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
@ -589,7 +721,7 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
else:
@ -597,7 +729,9 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][0] # flatten list returned from hf
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response["object"] = "list"
@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
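For orientation, a minimal standalone sketch (not library code; the function name and model string are invented for illustration) of how the Hugging Face embedding handler above shapes raw embedding output into the OpenAI-style response it returns, with usage counted from the tokenized inputs:

def shape_hf_embedding_response(model, raw_embeddings, input_token_counts):
    # one entry per input, mirroring the output_data loop above
    data = [
        {"object": "embedding", "index": idx, "embedding": emb}
        for idx, emb in enumerate(raw_embeddings)
    ]
    total = sum(input_token_counts)
    return {
        "object": "list",
        "data": data,
        "model": model,
        # embeddings only consume prompt tokens, so both counts are the same
        "usage": {"prompt_tokens": total, "total_tokens": total},
    }

shape_hf_embedding_response("hf/some-embedding-model", [[0.1, 0.2], [0.3, 0.4]], [4, 5])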

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
class MaritalkError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,7 +16,8 @@ class MaritalkError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class MaritTalkConfig():
class MaritTalkConfig:
"""
The class `MaritTalkConfig` provides configuration for the MaritTalk API interface. Here are the parameters:
@ -33,6 +35,7 @@ class MaritTalkConfig():
- `stopping_tokens` (list of string): List of tokens at which the conversation can be stopped.
"""
max_tokens: Optional[int] = None
model: Optional[str] = None
do_sample: Optional[bool] = None
@ -41,26 +44,39 @@ class MaritTalkConfig():
repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None
def __init__(self,
max_tokens: Optional[int]=None,
def __init__(
self,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None) -> None:
stopping_tokens: Optional[List[str]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
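The same Config pattern repeats for every provider class below (NLPCloud, Ollama, OpenAI, Palm, OpenRouter): the constructor copies any non-None keyword arguments onto the class itself, and get_config() returns the non-dunder, non-callable class attributes as a dict of defaults. A minimal standalone sketch of the pattern, using an invented ExampleConfig:

import types

class ExampleConfig:
    temperature = None

    def __init__(self, temperature=None):
        # mirror the provider configs: non-None kwargs become class-level defaults
        for key, value in locals().items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # keep only plain, non-dunder, non-callable class attributes
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod),
            )
            and v is not None
        }

ExampleConfig(temperature=0.2)  # sets a class-level default
assert ExampleConfig.get_config() == {"temperature": 0.2}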
def validate_environment(api_key):
headers = {
@ -71,6 +87,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Key {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -89,9 +106,11 @@ def completion(
model = model
## Load Config
config=litellm.MaritTalkConfig.get_config()
config = litellm.MaritTalkConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -107,7 +126,10 @@ def completion(
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
@ -130,15 +152,17 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["answer"]
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
except Exception as e:
raise MaritalkError(message=response.text, status_code=response.status_code)
raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages)
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -148,11 +172,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
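As a plain-terms note on the precedence comment in the config-merge loop above (values here are illustrative): whatever the provider config holds acts as a default, and any parameter passed on the call itself wins.

config = {"temperature": 0.7, "max_tokens": 200}  # e.g. values from MaritTalkConfig.get_config()
optional_params = {"temperature": 0.1}  # passed explicitly on the completion call
for k, v in config.items():
    if k not in optional_params:
        optional_params[k] = v
assert optional_params == {"temperature": 0.1, "max_tokens": 200}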
def embedding(
model: str,
input: list,

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, Usage
class NLPCloudError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,7 +16,8 @@ class NLPCloudError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class NLPCloudConfig():
class NLPCloudConfig:
"""
Reference: https://docs.nlpcloud.com/#generation
@ -43,45 +45,57 @@ class NLPCloudConfig():
- `num_return_sequences` (int): Optional. The number of independently computed returned sequences.
"""
max_length: Optional[int]=None
length_no_input: Optional[bool]=None
end_sequence: Optional[str]=None
remove_end_sequence: Optional[bool]=None
remove_input: Optional[bool]=None
bad_words: Optional[list]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
num_beams: Optional[int]=None
num_return_sequences: Optional[int]=None
max_length: Optional[int] = None
length_no_input: Optional[bool] = None
end_sequence: Optional[str] = None
remove_end_sequence: Optional[bool] = None
remove_input: Optional[bool] = None
bad_words: Optional[list] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
num_beams: Optional[int] = None
num_return_sequences: Optional[int] = None
def __init__(self,
max_length: Optional[int]=None,
length_no_input: Optional[bool]=None,
end_sequence: Optional[str]=None,
remove_end_sequence: Optional[bool]=None,
remove_input: Optional[bool]=None,
bad_words: Optional[list]=None,
temperature: Optional[float]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None,
repetition_penalty: Optional[float]=None,
num_beams: Optional[int]=None,
num_return_sequences: Optional[int]=None) -> None:
def __init__(
self,
max_length: Optional[int] = None,
length_no_input: Optional[bool] = None,
end_sequence: Optional[str] = None,
remove_end_sequence: Optional[bool] = None,
remove_input: Optional[bool] = None,
bad_words: Optional[list] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
num_beams: Optional[int] = None,
num_return_sequences: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -93,6 +107,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -112,7 +127,9 @@ def completion(
## Load Config
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url_fragment_1 = api_base
@ -131,11 +148,18 @@ def completion(
logging_obj.pre_call(
input=text,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return clean_and_iterate_chunks(response)
@ -161,9 +185,14 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["generated_text"]
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
except:
raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code)
raise NLPCloudError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - NLP Cloud returns token counts in its response, so use them directly
prompt_tokens = completion_response["nb_input_tokens"]
@ -174,7 +203,7 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
@ -187,25 +216,27 @@ def completion(
# # Perform further processing based on your needs
# return cleaned_chunk
# for line in response.iter_lines():
# if line:
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b''
buffer = b""
for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break
buffer += chunk
while b'\x00' in buffer:
buffer = buffer.replace(b'\x00', b'')
yield buffer.decode('utf-8')
buffer = b''
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""
# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode('utf-8')
yield buffer.decode("utf-8")
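A quick standalone check of what clean_and_iterate_chunks does with NLP Cloud's byte stream; FakeResponse is invented purely for this illustration, and the snippet assumes the function defined above is in scope. NUL delimiters are stripped and each buffered piece is yielded as decoded text.

class FakeResponse:
    def iter_content(self, chunk_size=1024):
        # pretend the server sends NUL-terminated fragments
        yield b"Hello\x00 wor"
        yield b"ld\x00"

assert list(clean_and_iterate_chunks(FakeResponse())) == ["Hello wor", "ld"]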
def embedding():
# logic for parsing in - calling - parsing out model embedding calls

View file

@ -6,6 +6,7 @@ import litellm
import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -16,7 +17,8 @@ class OllamaError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class OllamaConfig():
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
@ -56,51 +58,68 @@ class OllamaConfig():
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
"""
mirostat: Optional[int]=None
mirostat_eta: Optional[float]=None
mirostat_tau: Optional[float]=None
num_ctx: Optional[int]=None
num_gqa: Optional[int]=None
num_thread: Optional[int]=None
repeat_last_n: Optional[int]=None
repeat_penalty: Optional[float]=None
temperature: Optional[float]=None
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float]=None
num_predict: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
system: Optional[str]=None
template: Optional[str]=None
def __init__(self,
mirostat: Optional[int]=None,
mirostat_eta: Optional[float]=None,
mirostat_tau: Optional[float]=None,
num_ctx: Optional[int]=None,
num_gqa: Optional[int]=None,
num_thread: Optional[int]=None,
repeat_last_n: Optional[int]=None,
repeat_penalty: Optional[float]=None,
temperature: Optional[float]=None,
stop: Optional[list]=None,
tfs_z: Optional[float]=None,
num_predict: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
system: Optional[str]=None,
template: Optional[str]=None) -> None:
mirostat: Optional[int] = None
mirostat_eta: Optional[float] = None
mirostat_tau: Optional[float] = None
num_ctx: Optional[int] = None
num_gqa: Optional[int] = None
num_thread: Optional[int] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[
list
] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
system: Optional[str] = None
template: Optional[str] = None
def __init__(
self,
mirostat: Optional[int] = None,
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
num_ctx: Optional[int] = None,
num_gqa: Optional[int] = None,
num_thread: Optional[int] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# ollama implementation
def get_ollama_response(
@ -111,36 +130,51 @@ def get_ollama_response(
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None
):
encoding=None,
):
if api_base.endswith("/api/generate"):
url = api_base
else:
url = f"{api_base}/api/generate"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
)
if acompletion is True:
if optional_params.get("stream", False) == True:
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response
elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
@ -168,7 +202,16 @@ def get_ollama_response(
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
@ -176,44 +219,59 @@ def get_ollama_response(
model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
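Shape-only illustration (plain dicts rather than the litellm.Message object used above, payload invented) of what the format="json" branch above produces: the raw JSON string returned by Ollama is surfaced as a tool call's arguments instead of ordinary message content.

import json, uuid

response_json = {"response": json.dumps({"city": "Paris"})}  # illustrative Ollama payload
message = {
    "content": None,
    "tool_calls": [
        {
            "id": f"call_{uuid.uuid4()}",
            "function": {"arguments": response_json["response"], "name": ""},
            "type": "function",
        }
    ],
}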
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url,
json=data,
method="POST",
timeout=litellm.request_timeout
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
try:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}",
json=data,
method="POST",
timeout=litellm.request_timeout
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
traceback.print_exc()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
data["stream"] = False
try:
@ -227,7 +285,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## LOGGING
logging_obj.post_call(
input=data['prompt'],
input=data["prompt"],
api_key="",
original_response=resp.text,
additional_args={
@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": response_json["response"],
"name": "",
},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["choices"][0]["message"]["content"] = response_json[
"response"
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model']
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
except Exception as e:
traceback.print_exc()
raise e
async def ollama_aembeddings(api_base="http://localhost:11434",
async def ollama_aembeddings(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None):
encoding=None,
):
if api_base.endswith("/api/embeddings"):
url = api_base
else:
url = f"{api_base}/api/embeddings"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -308,11 +388,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class OobaboogaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,6 +16,7 @@ class OobaboogaError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
def validate_environment(api_key):
headers = {
"accept": "application/json",
@ -24,6 +26,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -45,7 +48,10 @@ def completion(
elif api_base:
completion_url = api_base
else:
raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
raise OobaboogaError(
status_code=404,
message="API Base not set. Set one via completion(..,api_base='your-api-url')",
)
model = model
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
@ -54,7 +60,7 @@ def completion(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
@ -72,7 +78,10 @@ def completion(
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
@ -89,7 +98,9 @@ def completion(
try:
completion_response = response.json()
except:
raise OobaboogaError(message=response.text, status_code=response.status_code)
raise OobaboogaError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise OobaboogaError(
message=completion_response["error"],
@ -97,14 +108,17 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text']
model_response["choices"][0]["message"][
"content"
] = completion_response["results"][0]["text"]
except:
raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
@ -114,11 +128,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -2,15 +2,29 @@ from typing import Optional, Union, Any
import types, time, json
import httpx
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
)
from typing import Callable, Optional
import aiohttp, requests
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
class OpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request:
@ -20,13 +34,15 @@ class OpenAIError(Exception):
if response:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class OpenAIConfig():
class OpenAIConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -52,42 +68,56 @@ class OpenAIConfig():
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int]=None
function_call: Optional[Union[str, dict]]=None
functions: Optional[list]=None
logit_bias: Optional[dict]=None
max_tokens: Optional[int]=None
n: Optional[int]=None
presence_penalty: Optional[int]=None
stop: Optional[Union[str, list]]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
def __init__(self,
frequency_penalty: Optional[int]=None,
function_call: Optional[Union[str, dict]]=None,
functions: Optional[list]=None,
logit_bias: Optional[dict]=None,
max_tokens: Optional[int]=None,
n: Optional[int]=None,
presence_penalty: Optional[int]=None,
stop: Optional[Union[str, list]]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,) -> None:
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class OpenAITextCompletionConfig():
class OpenAITextCompletionConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/completions/create
@ -117,65 +147,80 @@ class OpenAITextCompletionConfig():
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
best_of: Optional[int]=None
echo: Optional[bool]=None
frequency_penalty: Optional[int]=None
logit_bias: Optional[dict]=None
logprobs: Optional[int]=None
max_tokens: Optional[int]=None
n: Optional[int]=None
presence_penalty: Optional[int]=None
stop: Optional[Union[str, list]]=None
suffix: Optional[str]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
def __init__(self,
best_of: Optional[int]=None,
echo: Optional[bool]=None,
frequency_penalty: Optional[int]=None,
logit_bias: Optional[dict]=None,
logprobs: Optional[int]=None,
max_tokens: Optional[int]=None,
n: Optional[int]=None,
presence_penalty: Optional[int]=None,
stop: Optional[Union[str, list]]=None,
suffix: Optional[str]=None,
temperature: Optional[float]=None,
top_p: Optional[float]=None) -> None:
best_of: Optional[int] = None
echo: Optional[bool] = None
frequency_penalty: Optional[int] = None
logit_bias: Optional[dict] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
suffix: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(self,
def completion(
self,
model_response: ModelResponse,
timeout: float,
model: Optional[str]=None,
messages: Optional[list]=None,
print_verbose: Optional[Callable]=None,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict]=None,
custom_prompt_dict: dict={},
client=None
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
):
super().completion()
exception_mapping_worked = False
@ -186,36 +231,79 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=422, message=f"Missing model or messages")
if not isinstance(timeout, float):
raise OpenAIError(status_code=422, message=f"Timeout needs to be a float")
raise OpenAIError(
status_code=422, message=f"Timeout needs to be a float"
)
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
data = {
"model": model,
"messages": messages,
**optional_params
}
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
data = {"model": model, "messages": messages, **optional_params}
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else:
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.acompletion(
data=data,
headers=headers,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
return self.streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": acompletion,
"complete_input_dict": data,
},
)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore
@ -226,16 +314,23 @@ class OpenAIChatCompletion(BaseLLM):
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
if "Conversation roles must alternate user/assistant" in str(
e
) or "user and assistant roles should be alternating" in str(e):
# reformat messages to ensure user/assistant roles alternate; if there are 2 consecutive 'user' or 'assistant' messages, add a blank message of the opposite role to ensure compatibility
new_messages = []
for i in range(len(messages)-1):
for i in range(len(messages) - 1):
new_messages.append(messages[i])
if messages[i]["role"] == messages[i+1]["role"]:
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
new_messages.append(
{"role": "assistant", "content": ""}
)
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
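Standalone sketch (helper name invented here) of the repair performed above before the request is retried: when two consecutive messages share a role, a blank message with the opposite role is inserted between them.

def alternate_roles(messages):
    fixed = []
    for i in range(len(messages) - 1):
        fixed.append(messages[i])
        if messages[i]["role"] == messages[i + 1]["role"]:
            filler_role = "assistant" if messages[i]["role"] == "user" else "user"
            fixed.append({"role": filler_role, "content": ""})
    fixed.append(messages[-1])
    return fixed

msgs = [{"role": "user", "content": "a"}, {"role": "user", "content": "b"}]
assert alternate_roles(msgs) == [
    {"role": "user", "content": "a"},
    {"role": "assistant", "content": ""},
    {"role": "user", "content": "b"},
]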
@ -252,118 +347,179 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
raise e
async def acompletion(self,
async def acompletion(
self,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
headers=None
headers=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=openai_aclient.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
"api_base": openai_aclient._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await openai_aclient.chat.completions.create(**data)
stringified_response = response.model_dump_json()
logging_obj.post_call(
input=data['messages'],
input=data["messages"],
api_key=api_key,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except Exception as e:
raise e
def streaming(self,
def streaming(
self,
logging_obj,
timeout: float,
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
client = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
headers=None
headers=None,
):
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": False,
"complete_input_dict": data,
},
)
response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streamwrapper
async def async_streaming(self,
async def async_streaming(
self,
logging_obj,
timeout: float,
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
headers=None
headers=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions.
except (
Exception
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.
if response is not None and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
async def aembedding(
self,
input: list,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None
logging_obj=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data) # type: ignore
@ -385,7 +541,8 @@ class OpenAIChatCompletion(BaseLLM):
)
raise e
def embedding(self,
def embedding(
self,
model: str,
input: list,
timeout: float,
@ -401,11 +558,7 @@ class OpenAIChatCompletion(BaseLLM):
exception_mapping_worked = False
try:
model = model
data = {
"model": model,
"input": input,
**optional_params
}
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
@ -420,7 +573,13 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
@ -443,6 +602,7 @@ class OpenAIChatCompletion(BaseLLM):
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
@ -451,16 +611,22 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None
logging_obj=None,
):
response = None
try:
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data) # type: ignore
@ -482,7 +648,8 @@ class OpenAIChatCompletion(BaseLLM):
)
raise e
def image_generation(self,
def image_generation(
self,
model: Optional[str],
prompt: str,
timeout: float,
@ -497,11 +664,7 @@ class OpenAIChatCompletion(BaseLLM):
exception_mapping_worked = False
try:
model = model
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
@ -511,7 +674,13 @@ class OpenAIChatCompletion(BaseLLM):
# return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else:
openai_client = client
@ -519,7 +688,12 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj.pre_call(
input=prompt,
api_key=openai_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
## COMPLETION CALL
@ -541,8 +715,10 @@ class OpenAIChatCompletion(BaseLLM):
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client
@ -558,15 +734,21 @@ class OpenAITextCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def convert_to_model_response_object(self, response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
def convert_to_model_response_object(
self,
response_object: Optional[dict] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list=[]
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["text"], role="assistant")
choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice)
model_response_object.choices = choice_list
@ -579,24 +761,28 @@ class OpenAITextCompletion(BaseLLM):
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
model_response_object._hidden_params[
"original_response"
] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
return model_response_object
except Exception as e:
raise e
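Shape-only illustration of the conversion above, using a made-up /completions payload: each text-completion choice is rewrapped as an assistant message so the result matches the chat-style ModelResponse.

response_object = {
    "id": "cmpl-123",  # illustrative values throughout
    "model": "text-davinci-003",
    "created": 1700000000,
    "choices": [{"text": "Hello!", "finish_reason": "stop", "index": 0}],
    "usage": {"prompt_tokens": 2, "completion_tokens": 2, "total_tokens": 4},
}
# After convert_to_model_response_object(response_object, ModelResponse()), the
# first choice looks roughly like:
#   Choices(finish_reason="stop", index=0,
#           message=Message(content="Hello!", role="assistant"))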
def completion(self,
def completion(
self,
model_response: ModelResponse,
api_key: str,
model: str,
messages: list,
print_verbose: Optional[Callable]=None,
api_base: Optional[str]=None,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
logging_obj=None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict]=None):
headers: Optional[dict] = None,
):
super().completion()
exception_mapping_worked = False
try:
@ -607,7 +793,11 @@ class OpenAITextCompletion(BaseLLM):
api_base = f"{api_base}/completions"
if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list:
if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
):
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
@ -615,24 +805,38 @@ class OpenAITextCompletion(BaseLLM):
# don't send max retries to the api, if set
optional_params.pop("max_retries", None)
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "complete_input_dict": data},
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if acompletion == True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
headers=headers,
model_response=model_response,
model=model,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
headers=headers,
model_response=model_response,
model=model,
)
else:
response = httpx.post(
url=f"{api_base}",
@ -640,7 +844,9 @@ class OpenAITextCompletion(BaseLLM):
headers=headers,
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
@ -654,11 +860,15 @@ class OpenAITextCompletion(BaseLLM):
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
return self.convert_to_model_response_object(
response_object=response.json(),
model_response_object=model_response,
)
except Exception as e:
raise e
async def acompletion(self,
async def acompletion(
self,
logging_obj,
api_base: str,
data: dict,
@ -666,13 +876,21 @@ class OpenAITextCompletion(BaseLLM):
model_response: ModelResponse,
prompt: str,
api_key: str,
model: str):
model: str,
):
async with httpx.AsyncClient() as client:
try:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response = await client.post(
api_base,
json=data,
headers=headers,
timeout=litellm.request_timeout,
)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
@ -686,52 +904,71 @@ class OpenAITextCompletion(BaseLLM):
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
return self.convert_to_model_response_object(
response_object=response_json, model_response_object=model_response
)
except Exception as e:
raise e
def streaming(self,
def streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str
model: str,
):
with httpx.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
timeout=litellm.request_timeout
timeout=litellm.request_timeout,
) as response:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(
status_code=response.status_code, message=response.text
)
streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
async def async_streaming(self,
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
model: str,
):
client = httpx.AsyncClient()
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST",
timeout=litellm.request_timeout
timeout=litellm.request_timeout,
) as response:
try:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(
status_code=response.status_code, message=response.text
)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:

View file

@ -1,30 +1,41 @@
from typing import List, Dict
import types
class OpenrouterConfig():
class OpenrouterConfig:
"""
Reference: https://openrouter.ai/docs#format
"""
# OpenRouter-only parameters
extra_body: Dict[str, List[str]] = {
'transforms': [] # default transforms to []
}
extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
def __init__(self,
def __init__(
self,
transforms: List[str] = [],
models: List[str] = [],
route: str = '',
route: str = "",
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

View file

@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys, httpx
class PalmError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class PalmConfig():
class PalmConfig:
"""
Reference: https://developers.generativeai.google/api/python/google/generativeai/chat
@ -37,35 +42,47 @@ class PalmConfig():
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
"""
context: Optional[str]=None
examples: Optional[list]=None
temperature: Optional[float]=None
candidate_count: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
max_output_tokens: Optional[int]=None
def __init__(self,
context: Optional[str]=None,
examples: Optional[list]=None,
temperature: Optional[float]=None,
candidate_count: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
max_output_tokens: Optional[int]=None) -> None:
context: Optional[str] = None
examples: Optional[list] = None
temperature: Optional[float] = None
candidate_count: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
max_output_tokens: Optional[int] = None
def __init__(
self,
context: Optional[str] = None,
examples: Optional[list] = None,
temperature: Optional[float] = None,
candidate_count: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
@ -83,30 +100,32 @@ def completion(
try:
import google.generativeai as palm
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
palm.configure(api_key=api_key)
model = model
## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.PalmConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
@ -142,22 +161,25 @@ def completion(
message_obj = Message(content=item["output"])
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj)
choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
raise PalmError(status_code=400, message=f"No response received. Original response - {response}")
raise PalmError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -167,11 +189,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
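# Sketch of the precedence rule noted in completion() above
# ("completion(top_k=3) > palm_config(top_k=3)"): per-call params win over the
# provider-level config. Assumes litellm is importable; values are illustrative.
import litellm

litellm.PalmConfig(top_k=3, temperature=0.5)  # provider-level defaults
inference_params = {"top_k": 10}  # params passed directly to completion()
for k, v in litellm.PalmConfig.get_config().items():
    if k not in inference_params:  # only fill in what the caller did not set
        inference_params[k] = v
# inference_params is now roughly {"top_k": 10, "temperature": 0.5}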

View file

@ -8,6 +8,7 @@ import litellm
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class PetalsError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -16,7 +17,8 @@ class PetalsError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class PetalsConfig():
class PetalsConfig:
"""
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
@ -37,33 +39,52 @@ class PetalsConfig():
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
"""
max_length: Optional[int]=None
max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
repetition_penalty: Optional[float]=None
def __init__(self,
max_length: Optional[int]=None,
max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool]=None,
temperature: Optional[float]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
repetition_penalty: Optional[float]=None) -> None:
max_length: Optional[int] = None
max_new_tokens: Optional[
int
] = litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
model: str,
@ -81,7 +102,9 @@ def completion(
## Load Config
config = litellm.PetalsConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
if model in litellm.custom_prompt_dict:
@ -91,7 +114,7 @@ def completion(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
@ -101,13 +124,12 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": optional_params, "api_base": api_base},
additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
},
)
data = {
"model": model,
"inputs": prompt,
**optional_params
}
data = {"model": model, "inputs": prompt, **optional_params}
## COMPLETION CALL
response = requests.post(api_base, data=data)
@ -138,7 +160,9 @@ def completion(
model = model
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False)
tokenizer = AutoTokenizer.from_pretrained(
model, use_fast=False, add_bos_token=False
)
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
## LOGGING
@ -167,9 +191,7 @@ def completion(
if len(output_text) > 0:
model_response["choices"][0]["message"]["content"] = output_text
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
@ -179,11 +201,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -4,9 +4,11 @@ import json
from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any
def default_pt(messages):
return " ".join(message["content"] for message in messages)
# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
prompt = custom_prompt(
@ -19,48 +21,45 @@ def alpaca_pt(messages):
"pre_message": "### Instruction:\n",
"post_message": "\n\n",
},
"assistant": {
"pre_message": "### Response:\n",
"post_message": "\n\n"
}
"assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
},
bos_token="<s>",
eos_token="</s>",
messages=messages
messages=messages,
)
return prompt
# Llama2 prompt template
def llama_2_chat_pt(messages):
prompt = custom_prompt(
role_dict={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n"
"post_message": "\n<</SYS>>\n [/INST]\n",
},
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"pre_message": "[INST] ",
"post_message": " [/INST]\n"
"post_message": " [/INST]\n",
},
"assistant": {
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
}
},
},
messages=messages,
bos_token="<s>",
eos_token="</s>"
eos_token="</s>",
)
return prompt
def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
def ollama_pt(
model, messages
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt(
role_dict={
"system": {
"pre_message": "### System:\n",
"post_message": "\n"
},
"system": {"pre_message": "### System:\n", "post_message": "\n"},
"user": {
"pre_message": "### User:\n",
"post_message": "\n",
@ -68,10 +67,10 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
"assistant": {
"pre_message": "### Response:\n",
"post_message": "\n",
}
},
},
final_prompt_value="### Response:",
messages=messages
messages=messages,
)
elif "llava" in model:
prompt = ""
@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
return {
"prompt": prompt,
"images": images
}
return {"prompt": prompt, "images": images}
else:
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
prompt = "".join(
m["content"]
if isinstance(m["content"], str) is str
else "".join(m["content"])
for m in messages
)
return prompt
def mistral_instruct_pt(messages):
prompt = custom_prompt(
initial_prompt_value="<s>",
role_dict={
"system": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"user": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"assistant": {
"pre_message": "[INST]",
"post_message": "[/INST]"
}
"system": {"pre_message": "[INST]", "post_message": "[/INST]"},
"user": {"pre_message": "[INST]", "post_message": "[/INST]"},
"assistant": {"pre_message": "[INST]", "post_message": "[/INST]"},
},
final_prompt_value="</s>",
messages=messages
messages=messages,
)
return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
prompt = ""
@ -125,11 +119,16 @@ def falcon_instruct_pt(messages):
if message["role"] == "system":
prompt += message["content"]
else:
prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
prompt += (
message["role"]
+ ":"
+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
)
prompt += "\n\n"
return prompt
def falcon_chat_pt(messages):
prompt = ""
for message in messages:
@ -142,6 +141,7 @@ def falcon_chat_pt(messages):
return prompt
# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages):
prompt = ""
@ -154,6 +154,7 @@ def mpt_chat_pt(messages):
prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
return prompt
# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages):
prompt = ""
@ -166,6 +167,7 @@ def wizardcoder_pt(messages):
prompt += "### Response:\n" + message["content"] + "\n\n"
return prompt
# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages):
prompt = ""
@ -178,13 +180,17 @@ def phind_codellama_pt(messages):
prompt += "### Assistant\n" + message["content"] + "\n\n"
return prompt
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
## get the tokenizer config from huggingface
bos_token = ""
eos_token = ""
if chat_template is None:
def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
url = (
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
)
# Make a GET request to fetch the JSON data
response = requests.get(url)
if response.status_code == 200:
@ -193,8 +199,12 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
return {"status": "success", "tokenizer": tokenizer_config}
else:
return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
if (
tokenizer_config["status"] == "failure"
or "chat_template" not in tokenizer_config["tokenizer"]
):
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
@ -207,7 +217,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
# Create a template object from the template text
env = Environment()
env.globals['raise_exception'] = raise_exception
env.globals["raise_exception"] = raise_exception
try:
template = env.from_string(chat_template)
except Exception as e:
@ -216,7 +226,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def _is_system_in_template():
try:
# Try rendering the template with a system message
response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>")
response = template.render(
messages=[{"role": "system", "content": "test"}],
eos_token="<eos>",
bos_token="<bos>",
)
return True
# This will be raised if Jinja attempts to render the system message and it can't
@ -226,36 +240,54 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
try:
# Render the template with the provided values
if _is_system_in_template():
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=messages
)
else:
# treat a system message as a user message, if system not in template
try:
reformatted_messages = []
for message in messages:
if message["role"] == "system":
reformatted_messages.append({"role": "user", "content": message["content"]})
reformatted_messages.append(
{"role": "user", "content": message["content"]}
)
else:
reformatted_messages.append(message)
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages)
rendered_text = template.render(
bos_token=bos_token,
eos_token=eos_token,
messages=reformatted_messages,
)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e):
# reformat messages to ensure user/assistant are alternating; if there are 2 consecutive 'user' or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(reformatted_messages)-1):
for i in range(len(reformatted_messages) - 1):
new_messages.append(reformatted_messages[i])
if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]:
if (
reformatted_messages[i]["role"]
== reformatted_messages[i + 1]["role"]
):
if reformatted_messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
new_messages.append(
{"role": "assistant", "content": ""}
)
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(reformatted_messages[-1])
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=new_messages
)
return rendered_text
except Exception as e:
raise Exception(f"Error rendering template - {str(e)}")
# Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
def claude_2_1_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
"""
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
- you can't just pass a system message
@ -264,6 +296,7 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
if a system message is passed in and followed by an assistant message, insert a blank human message between them.
"""
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
@ -271,81 +304,88 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
elif message["role"] == "system":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
elif message["role"] == "assistant":
if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
return prompt
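# Worked example of the rule in the docstring above, using only the branches shown in
# claude_2_1_pt (assumes the function is in scope; inputs are illustrative):
msgs = [
    {"role": "system", "content": "You are a helpful bot."},
    {"role": "user", "content": "Hi"},
]
# claude_2_1_pt(msgs) ->
# "You are a helpful bot.\n\nHuman: Hi\n\nAssistant: "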
### TOGETHER AI
def get_model_info(token, model):
try:
headers = {
'Authorization': f'Bearer {token}'
}
response = requests.get('https://api.together.xyz/models/info', headers=headers)
headers = {"Authorization": f"Bearer {token}"}
response = requests.get("https://api.together.xyz/models/info", headers=headers)
if response.status_code == 200:
model_info = response.json()
for m in model_info:
if m["name"].lower().strip() == model.strip():
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
return m["config"].get("prompt_format", None), m["config"].get(
"chat_template", None
)
return None, None
else:
return None, None
except Exception as e: # safely fail a prompt template request
return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None:
return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None:
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
prompt = hf_chat_template(
model=None, messages=messages, chat_template=chat_template
)
elif prompt_format is not None:
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
prompt = custom_prompt(
role_dict={},
messages=messages,
initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt,
)
else:
prompt = default_pt(messages)
return prompt
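# Illustrative split performed by format_prompt_togetherai above: a Together AI
# prompt_format string carries a "{prompt}" placeholder; the text before it becomes the
# human prefix and the text after it the assistant suffix (the format string is an example).
prompt_format = "<human>: {prompt}\n<bot>:"
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
# human_prompt == "<human>: ", assistant_prompt == "\n<bot>:"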
###
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
def anthropic_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
for idx, message in enumerate(
messages
): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
elif message["role"] == "system":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
else:
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
if (
idx == 0 and message["role"] == "assistant"
): # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
return prompt
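# Worked example for anthropic_pt above (assumes the function is in scope; inputs are
# illustrative): system content is wrapped in <admin> tags inside a Human turn, and the
# prompt always ends with an Assistant turn.
msgs = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
]
# anthropic_pt(msgs) ->
# "\n\nHuman: <admin>Be terse.</admin>\n\nHuman: Hi\n\nAssistant: "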
def gemini_text_image_pt(messages: list):
"""
{
@ -367,7 +407,9 @@ def gemini_text_image_pt(messages: list):
try:
import google.generativeai as genai
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
prompt = ""
images = []
@ -387,26 +429,36 @@ def gemini_text_image_pt(messages: list):
content = [prompt] + images
return content
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
function_prompt = (
"Produce JSON OUTPUT ONLY! The following functions are available to you:"
)
for function in functions:
function_prompt += f"""\n{function}\n"""
function_added_to_prompt = False
for message in messages:
if "system" in message["role"]:
message['content'] += f"""{function_prompt}"""
message["content"] += f"""{function_prompt}"""
function_added_to_prompt = True
if function_added_to_prompt == False:
messages.append({'role': 'system', 'content': f"""{function_prompt}"""})
messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages
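# Quick illustration of function_call_prompt above (the function spec and question are
# toy examples): the JSON-only instruction plus each function schema is appended to the
# existing system message, or added as a new system message when none exists.
msgs = [{"role": "user", "content": "What's the weather in SF?"}]
funcs = [{"name": "get_weather", "parameters": {"location": "string"}}]
# function_call_prompt(msgs, funcs) appends a message like:
# {"role": "system", "content": "Produce JSON OUTPUT ONLY! ...\n{'name': 'get_weather', ...}\n"}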
# Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""):
def custom_prompt(
role_dict: dict,
messages: list,
initial_prompt_value: str = "",
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
prompt = bos_token + initial_prompt_value
bos_open = True
## a bos token is at the start of a system / human message
@ -418,8 +470,16 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += bos_token
bos_open = True
pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
pre_message_str = (
role_dict[role]["pre_message"]
if role in role_dict and "pre_message" in role_dict[role]
else ""
)
post_message_str = (
role_dict[role]["post_message"]
if role in role_dict and "post_message" in role_dict[role]
else ""
)
prompt += pre_message_str + message["content"] + post_message_str
if role == "assistant":
@ -429,7 +489,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += final_prompt_value
return prompt
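# Minimal custom_prompt sketch mirroring the alpaca role_dict earlier in this file
# (assumes custom_prompt is in scope; the message is illustrative): each message is
# wrapped with its role's pre_message/post_message strings, with bos/eos tokens placed
# around user/assistant turns.
prompt = custom_prompt(
    role_dict={
        "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
    },
    messages=[{"role": "user", "content": "Summarize LiteLLM."}],
    bos_token="<s>",
    eos_token="</s>",
)
# expected shape: "<s>### Instruction:\nSummarize LiteLLM.\n\n" (eos is appended after assistant turns)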
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
def prompt_factory(
model: str,
messages: list,
custom_llm_provider: Optional[str] = None,
api_key: Optional[str] = None,
):
original_model_name = model
model = model.lower()
if custom_llm_provider == "ollama":
@ -441,13 +507,17 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return anthropic_pt(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
return gemini_text_image_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
elif (
"tiiuae/falcon" in model
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if model == "tiiuae/falcon-180B-chat":
return falcon_chat_pt(messages=messages)
elif "instruct" in model:
@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return mpt_chat_pt(messages=messages)
elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
return llama_2_chat_pt(
messages=messages
) # https://huggingface.co/blog/codellama#conversational-instructions
elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages)
elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
elif "togethercomputer/llama-2" in model and (
"instruct" in model or "chat" in model
):
return llama_2_chat_pt(messages=messages)
elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]:
elif model in [
"gryphe/mythomax-l2-13b",
"gryphe/mythomix-l2-13b",
"gryphe/mythologic-l2-13b",
]:
return alpaca_pt(messages=messages)
else:
return hf_chat_template(original_model_name, messages)
except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)

View file

@ -8,17 +8,21 @@ import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class ReplicateError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments")
self.request = httpx.Request(
method="POST", url="https://api.replicate.com/v1/deployments"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class ReplicateConfig():
class ReplicateConfig:
"""
Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model.
@ -43,42 +47,57 @@ class ReplicateConfig():
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
"""
system_prompt: Optional[str]=None
max_new_tokens: Optional[int]=None
min_new_tokens: Optional[int]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
stop_sequences: Optional[str]=None
seed: Optional[int]=None
debug: Optional[bool]=None
def __init__(self,
system_prompt: Optional[str]=None,
max_new_tokens: Optional[int]=None,
min_new_tokens: Optional[int]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,
top_k: Optional[int]=None,
stop_sequences: Optional[str]=None,
seed: Optional[int]=None,
debug: Optional[bool]=None) -> None:
system_prompt: Optional[str] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop_sequences: Optional[str] = None
seed: Optional[int] = None
debug: Optional[bool] = None
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
def start_prediction(
version_id, input_data, api_token, api_base, logging_obj, print_verbose
):
base_url = api_base
if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment")
@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
initial_prediction_data = {
@ -100,22 +119,31 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
)
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
response = requests.post(
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
)
if response.status_code == 201:
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
status = ""
@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
if response.status_code == 200:
response_data = response.json()
if "output" in response_data:
output_string = "".join(response_data['output'])
output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get('status', None)
status = response_data.get("status", None)
logs = response_data.get("logs", "")
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = ""
@ -146,7 +178,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
status = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
@ -155,20 +187,24 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
response = requests.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data['status']
status = response_data["status"]
if "output" in response_data:
output_string = "".join(response_data['output'])
new_output = output_string[len(previous_output):]
output_string = "".join(response_data["output"])
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data['status']
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to extract version ID from model string
@ -178,6 +214,7 @@ def model_to_version_id(model):
return split_model[1]
return model
# Main function for prediction completion
def completion(
model: str,
@ -198,7 +235,9 @@ def completion(
## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
system_prompt = None
@ -233,38 +272,53 @@ def completion(
input_data = {
"prompt": prompt,
"system_prompt": system_prompt,
**optional_params
**optional_params,
}
# Otherwise, use the prompt as is
else:
input_data = {
"prompt": prompt,
**optional_params
}
input_data = {"prompt": prompt, **optional_params}
## COMPLETION CALL
## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response
## Step2: is handled with and without streaming
model_response["created"] = int(time.time()) # for pricing this must remain right before calling api
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
model_response["created"] = int(
time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
)
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request")
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
return handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
else:
result, logs = handle_prediction_response(prediction_url, api_key, print_verbose)
model_response["ended"] = time.time() # for pricing this must remain right after calling api
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
)
model_response[
"ended"
] = time.time() # for pricing this must remain right after calling api
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=result,
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
additional_args={
"complete_input_dict": input_data,
"logs": logs,
"api_base": prediction_url,
},
)
print_verbose(f"raw model_response: {result}")
@ -278,12 +332,14 @@ def completion(
# Calculate usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", "")))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["model"] = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
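# Sketch of the two-step Replicate flow noted above (start a prediction, then poll the
# returned URL). The token and prediction URL are placeholders; only the request shape
# used by handle_prediction_response is reproduced.
import time

import requests

headers = {
    "Authorization": "Token <REPLICATE_API_TOKEN>",
    "Content-Type": "application/json",
}
prediction_url = "https://api.replicate.com/v1/predictions/<id>"  # returned by start_prediction
while True:
    data = requests.get(prediction_url, headers=headers).json()
    if data.get("status") in ("succeeded", "failed", "canceled"):
        break
    time.sleep(0.5)
output = "".join(data.get("output", []))  # language models return a list of text chunks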

View file

@ -11,41 +11,60 @@ from copy import deepcopy
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class SagemakerError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker")
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class SagemakerConfig():
class SagemakerConfig:
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
"""
max_new_tokens: Optional[int]=None
top_p: Optional[float]=None
temperature: Optional[float]=None
return_full_text: Optional[bool]=None
def __init__(self,
max_new_tokens: Optional[int]=None,
top_p: Optional[float]=None,
temperature: Optional[float]=None,
return_full_text: Optional[bool]=None) -> None:
max_new_tokens: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
return_full_text: Optional[bool] = None
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
"""
SAGEMAKER AUTH Keys/Vars
@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion(
model: str,
messages: list,
@ -91,8 +111,8 @@ def completion(
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
@ -106,7 +126,9 @@ def completion(
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
model = model
@ -117,7 +139,7 @@ def completion(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
messages=messages,
)
else:
if hf_model_name is None:
@ -126,13 +148,14 @@ def completion(
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
hf_model_name = (
hf_model_name or model
)  # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
data = json.dumps({
"inputs": prompt,
"parameters": inference_params
}).encode('utf-8')
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8"
)
## LOGGING
request_str = f"""
@ -146,7 +169,11 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
## COMPLETION CALL
try:
@ -184,12 +211,13 @@ def completion(
model_response["choices"][0]["message"]["content"] = completion_output
except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -199,12 +227,14 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(model: str,
def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
print_verbose: Callable,
@ -213,12 +243,14 @@ def embedding(model: str,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None):
logger_fn=None,
):
"""
Supports Huggingface Jumpstart embeddings like GPT-6B
"""
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@ -240,8 +272,8 @@ def embedding(model: str,
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
@ -255,13 +287,13 @@ def embedding(model: str,
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
#### HF EMBEDDING LOGIC
data = json.dumps({
"text_inputs": input
}).encode('utf-8')
data = json.dumps({"text_inputs": input}).encode("utf-8")
## LOGGING
request_str = f"""
@ -295,7 +327,6 @@ def embedding(model: str,
original_response=response,
)
response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING
logging_obj.post_call(
@ -307,20 +338,17 @@ def embedding(model: str,
print_verbose(f"raw model_response: {response}")
if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response")
embeddings = response['embedding']
embeddings = response["embedding"]
if not isinstance(embeddings, list):
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
raise SagemakerError(
status_code=422, message=f"Response not in expected format - {embeddings}"
)
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
@ -329,8 +357,10 @@ def embedding(model: str,
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
model_response["usage"] = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
return model_response
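# Sketch of the boto3 client setup used in this file: explicit keys are used when
# passed in, otherwise credentials come from the environment, and the region defaults
# to us-west-2. get_secret() in the file reads the same env var; os.environ is used
# here only to keep the sketch standalone.
import os

import boto3

region_name = os.environ.get("AWS_REGION_NAME") or "us-west-2"
client = boto3.client(service_name="sagemaker-runtime", region_name=region_name)
# client.invoke_endpoint(EndpointName="<endpoint>", ContentType="application/json", Body=data)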

View file

@ -9,17 +9,21 @@ import httpx
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
self.request = httpx.Request(
method="POST", url="https://api.together.xyz/inference"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TogetherAIConfig():
class TogetherAIConfig:
"""
Reference: https://docs.together.ai/reference/inference
@ -39,33 +43,47 @@ class TogetherAIConfig():
- `logprobs` (int32, optional): This parameter is not described in the prompt.
"""
max_tokens: Optional[int]=None
stop: Optional[str]=None
temperature:Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
logprobs: Optional[int]=None
def __init__(self,
max_tokens: Optional[int]=None,
stop: Optional[str]=None,
temperature:Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None,
repetition_penalty: Optional[float]=None,
logprobs: Optional[int]=None) -> None:
max_tokens: Optional[int] = None
stop: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
logprobs: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stop: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -80,6 +98,7 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
@ -99,7 +118,9 @@ def completion(
## Load Config
config = litellm.TogetherAIConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
@ -115,7 +136,12 @@ def completion(
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
prompt = prompt_factory(
model=model,
messages=messages,
api_key=api_key,
custom_llm_provider="together_ai",
) # api key required to query together ai model list
data = {
"model": model,
@ -128,13 +154,14 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base},
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
response = requests.post(
api_base,
headers=headers,
@ -143,11 +170,7 @@ def completion(
)
return response.iter_lines()
else:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data)
)
response = requests.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=prompt,
@ -170,30 +193,38 @@ def completion(
)
elif "error" in completion_response["output"]:
raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code
message=json.dumps(completion_response["output"]),
status_code=response.status_code,
)
if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]
model_response["choices"][0]["message"]["content"] = completion_response[
"output"
]["choices"][0]["text"]
## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
print_verbose(
f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"]
model_response.choices[0].finish_reason = completion_response["output"][
"choices"
][0]["finish_reason"]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
import httpx
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/")
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIConfig():
class VertexAIConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@ -34,28 +38,42 @@ class VertexAIConfig():
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float]=None
max_output_tokens: Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
def __init__(self,
temperature: Optional[float]=None,
max_output_tokens: Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None) -> None:
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def _get_image_bytes_from_url(image_url: str) -> bytes:
try:
@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b'' # Return an empty bytes object or handle the error as needed
return b"" # Return an empty bytes object or handle the error as needed
def _load_image_from_url(image_url: str):
@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str):
Returns:
Image: The loaded image.
"""
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)
def _gemini_vision_convert_messages(
messages: list
):
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
@ -115,11 +138,23 @@ def _gemini_vision_convert_messages(
try:
import vertexai
except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
@ -159,6 +194,7 @@ def _gemini_vision_convert_messages(
except Exception as e:
raise e
def completion(
model: str,
messages: list,
@ -171,22 +207,30 @@ def completion(
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool=False
acompletion: bool = False,
):
try:
import vertexai
except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init(
project=vertex_project, location=vertex_location
)
vertexai.init(project=vertex_project, location=vertex_location)
## Load Config
config = litellm.VertexAIConfig.get_config()
@ -202,11 +246,19 @@ def completion(
raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]
safety_settings = [
gapic_content_types.SafetySetting(x) for x in safety_settings
]
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
prompt = " ".join(
[
message["content"]
for message in messages
if isinstance(message["content"], str)
]
)
mode = ""
@ -240,23 +292,68 @@ def completion(
if acompletion == True: # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True:
# async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
return async_streaming(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
return async_completion(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
encoding=encoding,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
if mode == "":
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
@ -268,20 +365,34 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True
stream=True,
)
optional_params["stream"] = True
return model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
response = llm_model.generate_content(
@ -293,37 +404,73 @@ def completion(
response_obj = response._raw_response
elif mode == "chat":
chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n"
request_str += f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = llm_model.predict(prompt, **optional_params).text
## LOGGING
@ -333,36 +480,53 @@ def completion(
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
model_response["choices"][0].finish_reason = response_obj.candidates[
0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
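The chat and text branches above share one stream workaround: the Vertex AI SDK rejects a `stream` kwarg, so it is popped before the call and `optional_params["stream"] = True` is restored afterwards so main.py still knows to transform the result as a streaming response. A minimal sketch of that pattern, with invented names rather than the litellm internals:
def call_with_stream_workaround(send_streaming_fn, prompt, optional_params):
    # The provider call rejects stream=True, so drop it before sending...
    optional_params.pop("stream", None)
    response_iterator = send_streaming_fn(prompt, **optional_params)
    # ...then restore the flag so downstream code treats this as a stream.
    optional_params["stream"] = True
    return response_iterator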
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
async def async_completion(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
encoding=None,
messages=None,
print_verbose=None,
**optional_params,
):
"""
Add support for acompletion calls for gemini-pro
"""
@ -373,8 +537,17 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
response = await llm_model._generate_content_async(
contents=content,
generation_config=GenerationConfig(**optional_params)
contents=content, generation_config=GenerationConfig(**optional_params)
)
completion_response = response.text
response_obj = response._raw_response
@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# chat-bison etc.
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text
elif mode == "text":
# gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text
@ -417,48 +610,74 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
model_response["choices"][0].finish_reason = response_obj.candidates[
0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
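Both the sync and async paths fall back to counting tokens themselves whenever the Vertex response carries no usage metadata. A hedged illustration of that fallback, using tiktoken as a stand-in for the `encoding` object litellm actually passes in:
import tiktoken

def count_usage(prompt: str, completion_text: str):
    enc = tiktoken.get_encoding("cl100k_base")  # assumption: any tokenizer with .encode() works here
    prompt_tokens = len(enc.encode(prompt))
    completion_tokens = len(enc.encode(completion_text))
    return prompt_tokens, completion_tokens, prompt_tokens + completion_tokens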
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
async def async_streaming(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
messages=None,
print_verbose=None,
**optional_params,
):
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True
elif mode == "vision":
stream = optional_params.pop("stream")
@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True
stream=True,
)
optional_params["stream"] = True
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model.predict_streaming_async(prompt, **optional_params)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
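The wrap-and-relay shape at the end of async_streaming is generic: a provider stream is wrapped so each chunk is converted to the OpenAI format before being handed to the caller. A stripped-down sketch with assumed names, not the CustomStreamWrapper implementation:
async def relay_stream(provider_stream, to_openai_chunk):
    # Convert each provider-native chunk before yielding it to the caller.
    async for raw_chunk in provider_stream:
        yield to_openai_chunk(raw_chunk)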
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass



@ -6,7 +6,10 @@ import time, httpx
from typing import Callable, Any
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
llm = None
class VLLMError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -17,17 +20,20 @@ class VLLMError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
# check if vllm is installed
def validate_environment(model: str):
global llm
try:
from vllm import LLM, SamplingParams # type: ignore
if llm is None:
llm = LLM(model=model)
return llm, SamplingParams
except Exception as e:
raise VLLMError(status_code=0, message=str(e))
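validate_environment lazily builds one module-level vLLM engine on first use and reuses it afterwards. The same pattern in isolation (a sketch, with assumed names):
_engine = None

def get_engine(model: str):
    global _engine
    if _engine is None:
        from vllm import LLM  # imported lazily so vllm stays an optional dependency
        _engine = LLM(model=model)
    return _engine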
def completion(
model: str,
messages: list,
@ -53,12 +59,11 @@ def completion(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -69,8 +74,9 @@ def completion(
if llm:
outputs = llm.generate(prompt, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
@ -96,16 +102,14 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def batch_completions(
model: str,
messages: list,
optional_params=None,
custom_prompt_dict={}
model: str, messages: list, optional_params=None, custom_prompt_dict={}
):
"""
Example usage:
@ -150,7 +154,7 @@ def batch_completions(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message
messages=message,
)
prompts.append(prompt)
else:
@ -161,7 +165,9 @@ def batch_completions(
if llm:
outputs = llm.generate(prompts, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
final_outputs = []
for output in outputs:
@ -178,12 +184,13 @@ def batch_completions(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
final_outputs.append(model_response)
return final_outputs
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

File diff suppressed because it is too large.


@ -3,10 +3,12 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
try:
return self.model_dump() # noqa
@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
request_timeout: Optional[int] = None
class Config:
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase):
id: Optional[str]
@ -57,21 +60,21 @@ class ModelInfoDelete(LiteLLMBase):
class ModelInfo(LiteLLMBase):
id: Optional[str]
mode: Optional[Literal['embedding', 'chat', 'completion']]
mode: Optional[Literal["embedding", "chat", "completion"]]
input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] = 2048 # assume 2048 if not set
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json
base_model: Optional[Literal
[
'gpt-4-1106-preview',
'gpt-4-32k',
'gpt-4',
'gpt-3.5-turbo-16k',
'gpt-3.5-turbo',
'text-embedding-ada-002',
base_model: Optional[
Literal[
"gpt-4-1106-preview",
"gpt-4-32k",
"gpt-4",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo",
"text-embedding-ada-002",
]
]
@ -79,7 +82,6 @@ class ModelInfo(LiteLLMBase):
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("id") is None:
@ -97,7 +99,6 @@ class ModelInfo(LiteLLMBase):
return values
class ModelParams(LiteLLMBase):
model_name: str
litellm_params: dict
@ -112,6 +113,7 @@ class ModelParams(LiteLLMBase):
values.update({"model_info": ModelInfo()})
return values
class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h"
models: Optional[list] = []
@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class UpdateKeyRequest(LiteLLMBase):
key: str
duration: Optional[str] = None
@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
"""
Return the row in the db
"""
api_key: Optional[str] = None
models: list = []
aliases: dict = {}
@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
duration: str = "1h"
metadata: dict = {}
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: Optional[datetime]
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None
class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None
class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
"""
completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault")
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy")
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key")
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.")
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
health_check_interval: int = Field(300, description="background health check interval in seconds")
completion_model: Optional[str] = Field(
None, description="proxy level default model for all chat completion calls"
)
use_azure_key_vault: Optional[bool] = Field(
None, description="load keys from azure key vault"
)
master_key: Optional[str] = Field(
None, description="require a key for all calls to proxy"
)
database_url: Optional[str] = Field(
None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
otel: Optional[bool] = Field(
None,
description="[BETA] OpenTelemetry support - this might change, use with caution.",
)
custom_auth: Optional[str] = Field(
None,
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
)
max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key"
)
infer_model_from_keys: Optional[bool] = Field(
None,
description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
)
background_health_checks: Optional[bool] = Field(
None, description="run health checks in background"
)
health_check_interval: int = Field(
300, description="background health check interval in seconds"
)
class ConfigYAML(LiteLLMBase):
"""
Documents all the fields supported by the config.yaml
"""
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
model_list: Optional[List[ModelParams]] = Field(
None,
description="List of supported models on the server, with model-specific configs",
)
litellm_settings: Optional[dict] = Field(
None,
description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
)
general_settings: Optional[ConfigGeneralSettings] = None
class Config:
protected_namespaces = ()
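ConfigGeneralSettings and ConfigYAML give the proxy a typed view of config.yaml. A hypothetical usage sketch, with every value invented for illustration:
settings = ConfigGeneralSettings(
    master_key="sk-1234",
    database_url="postgresql://user:password@localhost:5432/litellm",
    max_parallel_requests=100,
    background_health_checks=True,
)
config = ConfigYAML(general_settings=settings)  # model_list / litellm_settings stay optional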


@ -4,6 +4,8 @@ from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"


@ -10,12 +10,14 @@ from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
class MyCustomHandler(CustomLogger):
def __init__(self):
blue_color_code = "\033[94m"
@ -23,7 +25,11 @@ class MyCustomHandler(CustomLogger):
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
try:
print_verbose(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print_verbose the methods
for method in methods:
@ -32,7 +38,6 @@ class MyCustomHandler(CustomLogger):
except:
pass
def log_pre_api_call(self, model, messages, kwargs):
print_verbose(f"Pre-API Call")
@ -45,7 +50,6 @@ class MyCustomHandler(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj)


@ -12,25 +12,16 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__)
ILLEGAL_DISPLAY_PARAMS = [
"messages",
"api_key"
]
ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
def _get_random_llm_message():
"""
Get a random message from the LLM.
"""
messages = [
"Hey how's it going?",
"What's 1 + 1?"
]
messages = ["Hey how's it going?", "What's 1 + 1?"]
return [
{"role": "user", "content": random.choice(messages)}
]
return [{"role": "user", "content": random.choice(messages)}]
def _clean_litellm_params(litellm_params: dict):
@ -44,13 +35,16 @@ async def _perform_health_check(model_list: list):
"""
Perform a health check for each model in the list.
"""
async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None)
model_params["prompt"] = "test from litellm"
try:
await litellm.aimage_generation(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
@ -60,16 +54,19 @@ async def _perform_health_check(model_list: list):
try:
await litellm.aembedding(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
async def _check_model(model_params: dict):
try:
await litellm.acompletion(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list):
return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None):
async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
):
"""
Perform a health check on the system.
@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
"""
if not model_list:
if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
model_list = [
{"model_name": cli_model, "litellm_params": {"model": cli_model}}
]
else:
return [], []
@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
return healthy_endpoints, unhealthy_endpoints
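perform_health_check accepts either a config-style model_list or a single CLI model. A hedged example of driving it from an async entrypoint (model names are placeholders and provider API keys are assumed to be set in the environment):
import asyncio

model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
]
healthy, unhealthy = asyncio.run(perform_health_check(model_list))
print(f"{len(healthy)} healthy, {len(unhealthy)} unhealthy endpoints")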


@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
@ -15,7 +16,13 @@ class MaxBudgetLimiter(CustomLogger):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"


@ -5,8 +5,10 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
pass
@ -15,8 +17,13 @@ class MaxParallelRequestsHandler(CustomLogger):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
@ -39,8 +46,9 @@ class MaxParallelRequestsHandler(CustomLogger):
# Increase count for this token
cache.set_cache(request_count_api_key, int(current) + 1)
else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
@ -54,17 +62,23 @@ class MaxParallelRequestsHandler(CustomLogger):
request_count_api_key = f"{user_api_key}_request_count"
# check if it has collected an entire stream response
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
self.print_verbose(
f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
)
if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
# Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1
self.print_verbose(f"updated_value in success call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e:
self.print_verbose(e) # noqa
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
async def async_log_failure_call(
self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
):
try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
api_key = user_api_key_dict.api_key
@ -75,14 +89,18 @@ class MaxParallelRequestsHandler(CustomLogger):
return
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
if (
hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
and "Max parallel request limit reached" in str(original_exception)
):
pass # ignore failed calls due to max limit being reached
else:
request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1
self.print_verbose(f"updated_value in failure call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
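The limiter is a simple per-key counter: increment before the call, reject with a 429 at the limit, decrement once the call completes or fails for any other reason. The same bookkeeping with a plain dict standing in for DualCache (illustrative only):
request_counts: dict = {}

def acquire(api_key: str, max_parallel: int) -> bool:
    key = f"{api_key}_request_count"
    if request_counts.get(key, 0) >= max_parallel:
        return False  # the real hook raises HTTPException(status_code=429) here
    request_counts[key] = request_counts.get(key, 0) + 1
    return True

def release(api_key: str) -> None:
    key = f"{api_key}_request_count"
    request_counts[key] = max(request_counts.get(key, 1) - 1, 0)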


@ -6,34 +6,42 @@ from datetime import datetime
import importlib
from dotenv import load_dotenv
import operator
sys.path.append(os.getcwd())
config_filename = "litellm.secrets"
# Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
user_config_path = os.getenv(
"LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
)
load_dotenv()
from importlib import resources
import shutil
telemetry = None
def run_ollama_serve():
try:
command = ['ollama', 'serve']
command = ["ollama", "serve"]
with open(os.devnull, 'w') as devnull:
with open(os.devnull, "w") as devnull:
process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
print(f"""
print(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa
"""
) # noqa
def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo
repo_name = repo_url.split('/')[-1]
repo_name = repo_url.split("/")[-1]
repo_master = os.path.join(destination, "repo_master")
subprocess.run(['git', 'clone', repo_url, repo_master])
subprocess.run(["git", "clone", repo_url, repo_master])
# Move into the subfolder
subfolder_path = os.path.join(repo_master, subfolder)
@ -48,43 +56,152 @@ def clone_subfolder(repo_url, subfolder, destination):
shutil.copytree(source, dest_path)
# Remove cloned repo folder
subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")])
subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy")
def is_port_in_use(port):
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
return s.connect_ex(("localhost", port)) == 0
@click.command()
@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.')
@click.option('--port', default=8000, help='Port to bind the server to.')
@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up')
@click.option('--api_base', default=None, help='API base URL.')
@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.')
@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects')
@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")')
@click.option('--add_key', default=None, help='The model name to pass to litellm expects')
@click.option('--headers', default=None, help='headers for the API call')
@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config')
@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input')
@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints')
@click.option('--temperature', default=None, type=float, help='Set temperature for the model')
@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model')
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
@click.option('--local', is_flag=True, default=False, help='for local debugging')
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
@click.option("--host", default="0.0.0.0", help="Host for the server to listen on.")
@click.option("--port", default=8000, help="Port to bind the server to.")
@click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up")
@click.option("--api_base", default=None, help="API base URL.")
@click.option(
"--api_version",
default="2023-07-01-preview",
help="For azure - pass in the api version.",
)
@click.option(
"--model", "-m", default=None, help="The model name to pass to litellm expects"
)
@click.option(
"--alias",
default=None,
help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
)
@click.option(
"--add_key", default=None, help="The model name to pass to litellm expects"
)
@click.option("--headers", default=None, help="headers for the API call")
@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
@click.option(
"--debug", default=False, is_flag=True, type=bool, help="To debug the input"
)
@click.option(
"--use_queue",
default=False,
is_flag=True,
type=bool,
help="To use celery workers for async endpoints",
)
@click.option(
"--temperature", default=None, type=float, help="Set temperature for the model"
)
@click.option(
"--max_tokens", default=None, type=int, help="Set max tokens for the model"
)
@click.option(
"--request_timeout",
default=600,
type=int,
help="Set timeout in seconds for completion calls",
)
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
@click.option(
"--add_function_to_prompt",
is_flag=True,
help="If function passed but unsupported, pass it as prompt",
)
@click.option(
"--config",
"-c",
default=None,
help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
)
@click.option(
"--max_budget",
default=None,
type=float,
help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
)
@click.option(
"--telemetry",
default=True,
type=bool,
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
)
@click.option(
"--version",
"-v",
default=False,
is_flag=True,
type=bool,
help="Print LiteLLM version",
)
@click.option(
"--logs",
flag_value=False,
type=int,
help='Gets the "n" most recent logs. By default gets most recent log.',
)
@click.option(
"--health",
flag_value=True,
help="Make a chat/completions request to all llms in config.yaml",
)
@click.option(
"--test",
flag_value=True,
help="proxy chat completions url to make a test request to",
)
@click.option(
"--test_async",
default=False,
is_flag=True,
help="Calls async endpoints /queue/requests and /queue/response",
)
@click.option(
"--num_requests",
default=10,
type=int,
help="Number of requests to hit async endpoint with",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(
host,
port,
api_base,
api_version,
model,
alias,
add_key,
headers,
save,
debug,
temperature,
max_tokens,
request_timeout,
drop_params,
add_function_to_prompt,
config,
max_budget,
telemetry,
logs,
test,
local,
num_workers,
test_async,
num_requests,
use_queue,
health,
version,
):
global feature_telemetry
args = locals()
if local:
@ -99,17 +216,23 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if logs == 0: # default to 1
logs = 1
try:
with open('api_log.json') as f:
with open("api_log.json") as f:
data = json.load(f)
# convert keys to datetime objects
log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()}
log_times = {
datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()
}
# sort by timestamp
sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True)
sorted_times = sorted(
log_times.items(), key=operator.itemgetter(0), reverse=True
)
# get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
recent_logs = {
k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]
}
print(json.dumps(recent_logs, indent=4)) # noqa
except:
@ -117,18 +240,21 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return
if version == True:
pkg_version = importlib.metadata.version("litellm")
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
return
if model and "ollama" in model and api_base is None:
run_ollama_serve()
if test_async is True:
import requests, concurrent, time
api_base = f"http://{host}:{port}"
def _make_openai_completion():
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Write a short poem about the moon"}]
"messages": [
{"role": "user", "content": "Write a short poem about the moon"}
],
}
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished":
llm_response = polling_response["result"]
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
) # noqa
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)
@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
futures = []
start_time = time.time()
# Make concurrent calls
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor:
with concurrent.futures.ThreadPoolExecutor(
max_workers=concurrent_calls
) as executor:
for _ in range(concurrent_calls):
futures.append(executor.submit(_make_openai_completion))
@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return
if health != False:
import requests
print("\nLiteLLM: Health Testing models in config")
response = requests.get(url=f"http://{host}:{port}/health")
print(json.dumps(response.json(), indent=4))
return
if test != False:
click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy')
click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy")
import openai
if test == True: # flag value set
api_base = f"http://{host}:{port}"
else:
api_base = test
client = openai.OpenAI(
api_key="My API Key",
base_url=api_base
)
client = openai.OpenAI(api_key="My API Key", base_url=api_base)
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem"
"content": "this is a test request, write a short poem",
}
], max_tokens=256)
click.echo(f'\nLiteLLM: response from proxy {response}')
],
max_tokens=256,
)
click.echo(f"\nLiteLLM: response from proxy {response}")
print("\n Making streaming request to proxy")
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem"
"content": "this is a test request, write a short poem",
}
],
stream=True,
)
for chunk in response:
click.echo(f'LiteLLM: streaming response from proxy {chunk}')
click.echo(f"LiteLLM: streaming response from proxy {chunk}")
print("\n making completion request to proxy")
response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem')
response = client.completions.create(
model="gpt-3.5-turbo", prompt="this is a test request, write a short poem"
)
print(response)
return
else:
if headers:
headers = json.loads(headers)
save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, use_queue=use_queue)
save_worker_config(
model=model,
alias=alias,
api_base=api_base,
api_version=api_version,
debug=debug,
temperature=temperature,
max_tokens=max_tokens,
request_timeout=request_timeout,
max_budget=max_budget,
telemetry=telemetry,
drop_params=drop_params,
add_function_to_prompt=add_function_to_prompt,
headers=headers,
save=save,
config=config,
use_queue=use_queue,
)
try:
import uvicorn
except:
raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
raise ImportError(
"Uvicorn needs to be imported. Run - `pip install uvicorn`"
)
if port == 8000 and is_port_in_use(port):
port = random.randint(1024, 49152)
uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers)
uvicorn.run(
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
)
if __name__ == "__main__":

File diff suppressed because it is too large.


@ -1,71 +1,77 @@
from dotenv import load_dotenv
load_dotenv()
import json, subprocess
import psutil # Import the psutil library
import atexit
try:
### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
from celery import Celery
import redis
except:
import sys
# from dotenv import load_dotenv
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"redis",
"celery"
]
)
# load_dotenv()
# import json, subprocess
# import psutil # Import the psutil library
# import atexit
import time
import sys, os
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
import litellm
# try:
# ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
# from celery import Celery
# import redis
# except:
# import sys
# Redis connection setup
pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5)
redis_client = redis.Redis(connection_pool=pool)
# subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
# Celery setup
celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}")
celery_app.conf.update(
broker_pool_limit = None,
broker_transport_options = {'connection_pool': pool},
result_backend_transport_options = {'connection_pool': pool},
)
# import time
# import sys, os
# sys.path.insert(
# 0, os.path.abspath("../../..")
# ) # Adds the parent directory to the system path - for litellm local dev
# import litellm
# # Redis connection setup
# pool = redis.ConnectionPool(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# db=0,
# max_connections=5,
# )
# redis_client = redis.Redis(connection_pool=pool)
# # Celery setup
# celery_app = Celery(
# "tasks",
# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# )
# celery_app.conf.update(
# broker_pool_limit=None,
# broker_transport_options={"connection_pool": pool},
# result_backend_transport_options={"connection_pool": pool},
# )
# Celery task
@celery_app.task(name='process_job', max_retries=3)
def process_job(*args, **kwargs):
try:
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
response = llm_router.completion(*args, **kwargs) # type: ignore
if isinstance(response, litellm.ModelResponse):
response = response.model_dump_json()
return json.loads(response)
return str(response)
except Exception as e:
raise e
# # Celery task
# @celery_app.task(name="process_job", max_retries=3)
# def process_job(*args, **kwargs):
# try:
# llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
# response = llm_router.completion(*args, **kwargs) # type: ignore
# if isinstance(response, litellm.ModelResponse):
# response = response.model_dump_json()
# return json.loads(response)
# return str(response)
# except Exception as e:
# raise e
# Ensure Celery workers are terminated when the script exits
def cleanup():
try:
# Get a list of all running processes
for process in psutil.process_iter(attrs=['pid', 'name']):
# Check if the process is a Celery worker process
if process.info['name'] == 'celery':
print(f"Terminating Celery worker with PID {process.info['pid']}")
# Terminate the Celery worker process
psutil.Process(process.info['pid']).terminate()
except Exception as e:
print(f"Error during cleanup: {e}")
# Register the cleanup function to run when the script exits
atexit.register(cleanup)
# # Ensure Celery workers are terminated when the script exits
# def cleanup():
# try:
# # Get a list of all running processes
# for process in psutil.process_iter(attrs=["pid", "name"]):
# # Check if the process is a Celery worker process
# if process.info["name"] == "celery":
# print(f"Terminating Celery worker with PID {process.info['pid']}")
# # Terminate the Celery worker process
# psutil.Process(process.info["pid"]).terminate()
# except Exception as e:
# print(f"Error during cleanup: {e}")
# # Register the cleanup function to run when the script exits
# atexit.register(cleanup)
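Before it was commented out, this module kept celery and redis optional with a try/except import that installed them on demand for the async queue endpoints. The pattern on its own, with the re-import the original omitted (a sketch, not the committed code):
import subprocess, sys

try:
    from celery import Celery
    import redis
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
    from celery import Celery
    import redis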


@ -1,12 +1,15 @@
import os
from multiprocessing import Process
def run_worker(cwd):
os.chdir(cwd)
os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info")
os.system(
"celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info"
)
def start_worker(cwd):
cwd += "/queue"
worker_process = Process(target=run_worker, args=(cwd,))
worker_process.start()


@ -1,26 +1,34 @@
import sys, os
from dotenv import load_dotenv
load_dotenv()
# Add the path to the local folder to sys.path
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
# import sys, os
# from dotenv import load_dotenv
def start_rq_worker():
from rq import Worker, Queue, Connection
from redis import Redis
# Set up RQ connection
redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
print(redis_conn.ping()) # Should print True if connected successfully
# Create a worker and add the queue
try:
queue = Queue(connection=redis_conn)
worker = Worker([queue], connection=redis_conn)
except Exception as e:
print(f"Error setting up worker: {e}")
exit()
# load_dotenv()
# # Add the path to the local folder to sys.path
# sys.path.insert(
# 0, os.path.abspath("../../..")
# ) # Adds the parent directory to the system path - for litellm local dev
with Connection(redis_conn):
worker.work()
start_rq_worker()
# def start_rq_worker():
# from rq import Worker, Queue, Connection
# from redis import Redis
# # Set up RQ connection
# redis_conn = Redis(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# )
# print(redis_conn.ping()) # Should print True if connected successfully
# # Create a worker and add the queue
# try:
# queue = Queue(connection=redis_conn)
# worker = Worker([queue], connection=redis_conn)
# except Exception as e:
# print(f"Error setting up worker: {e}")
# exit()
# with Connection(redis_conn):
# worker.work()
# start_rq_worker()
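The RQ worker above (also commented out in this commit) is the Redis Queue counterpart to the Celery worker. For reference, a minimal sketch of pushing work onto such a queue with RQ; the task module and connection values are illustrative, and the enqueued callable must be importable by the worker process.

import os

from redis import Redis
from rq import Queue

from my_tasks import say_hello  # hypothetical module; RQ workers must be able to import the callable

redis_conn = Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD"),
)
queue = Queue(connection=redis_conn)
job = queue.enqueue(say_hello, "litellm")  # the worker started above picks this up
print(job.id)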

View file

@ -4,10 +4,7 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="test",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
async def litellm_completion():
@ -15,9 +12,10 @@ async def litellm_completion():
try:
response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request
messages=[
{"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180}
], # this is about 4k tokens per request
)
print(response)
return response
except Exception as e:
@ -27,7 +25,6 @@ async def litellm_completion():
pass
async def main():
start = time.time()
n = 60 # Send 60 concurrent requests, each with 4k tokens = 240k Tokens
@ -45,6 +42,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -4,10 +4,7 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="sk-1234",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
async def litellm_completion():
@ -26,7 +23,6 @@ async def litellm_completion():
pass
async def main():
start = time.time()
n = 1000 # Number of concurrent tasks
@ -44,6 +40,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -14,8 +14,8 @@ import pytest
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100
@ -35,7 +35,10 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"]
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create(
model="text-embedding-ada-002",
input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# )
return None
start_time = time.time()
# Number of concurrent calls (you can adjust this)
concurrent_calls = 500

View file

@ -4,11 +4,7 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="test",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
async def litellm_completion():
@ -17,11 +13,11 @@ async def litellm_completion():
print("starting embedding calls")
response = await litellm_client.embeddings.create(
model="text-embedding-ada-002",
input = [
input=[
"hello who are you" * 2000,
"hello who are you tomorrow 1234" * 1000,
"hello who are you tomorrow 1234" * 1000
]
"hello who are you tomorrow 1234" * 1000,
],
)
print(response)
return response
@ -33,7 +29,6 @@ async def litellm_completion():
pass
async def main():
start = time.time()
n = 100 # Number of concurrent tasks
@ -51,6 +46,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -14,8 +14,8 @@ import pytest
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100
@ -35,7 +35,10 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create(
model="text-embedding-ada-002",
input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# )
return None
start_time = time.time()
# Number of concurrent calls (you can adjust this)
concurrent_calls = 500

View file

@ -2,6 +2,7 @@ import requests
import time
import os
from dotenv import load_dotenv
load_dotenv()
@ -17,32 +18,30 @@ config = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'],
}
"api_key": os.environ["OPENAI_API_KEY"],
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.environ['AZURE_API_KEY'],
"api_key": os.environ["AZURE_API_KEY"],
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview"
}
}
"api_version": "2023-07-01-preview",
},
},
]
}
print("STARTING LOAD TEST Q")
print(os.environ['AZURE_API_KEY'])
print(os.environ["AZURE_API_KEY"])
response = requests.post(
url=f"{base_url}/key/generate",
json={
"config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
},
headers={
"Authorization": "Bearer sk-hosted-litellm"
}
headers={"Authorization": "Bearer sk-hosted-litellm"},
)
print("\nresponse from generating key", response.text)
@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key)
import concurrent.futures
def create_job_and_poll(request_num):
print(f"Creating a job on the proxy for request {request_num}")
job_response = requests.post(
url=f"{base_url}/queue/request",
json={
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': 'write a short poem'},
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "write a short poem"},
],
},
headers={
"Authorization": f"Bearer {generated_key}"
}
headers={"Authorization": f"Bearer {generated_key}"},
)
print(job_response.status_code)
print(job_response.text)
@ -84,12 +82,12 @@ def create_job_and_poll(request_num):
try:
print(f"\nPolling URL for request {request_num}", polling_url)
polling_response = requests.get(
url=polling_url,
headers={
"Authorization": f"Bearer {generated_key}"
}
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
)
print(
f"\nResponse from polling url for request {request_num}",
polling_response.text,
)
print(f"\nResponse from polling url for request {request_num}", polling_response.text)
polling_response = polling_response.json()
status = polling_response.get("status", None)
if status == "finished":
@ -109,6 +107,7 @@ def create_job_and_poll(request_num):
except Exception as e:
print("got exception when polling", e)
# Number of requests
num_requests = 100
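Only fragments of the polling logic survive in these hunks. The flow being load-tested is: POST /queue/request to create a job, then GET the returned polling URL until status == "finished". A condensed single-request sketch; base_url, the key, and the exact response field names are assumptions.

import time

import requests

base_url = "http://0.0.0.0:8000"  # placeholder proxy URL
key = "sk-..."  # placeholder: a key previously issued by /key/generate

job = requests.post(
    url=f"{base_url}/queue/request",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "system", "content": "write a short poem"}],
    },
    headers={"Authorization": f"Bearer {key}"},
).json()

polling_url = f"{base_url}{job['url']}"  # assumed response field holding the relative polling path
while True:
    poll = requests.get(polling_url, headers={"Authorization": f"Bearer {key}"}).json()
    if poll.get("status") == "finished":
        print(poll.get("result"))  # assumed field name for the completed response
        break
    time.sleep(0.5)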

View file

@ -26,4 +26,3 @@
# import asyncio
# asyncio.run(test_async_completion())

View file

@ -34,7 +34,3 @@
# response = claude_chat(messages)
# print(response)

View file

@ -2,6 +2,7 @@ import requests
import time
import os
from dotenv import load_dotenv
load_dotenv()
@ -17,8 +18,8 @@ config = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'],
}
"api_key": os.environ["OPENAI_API_KEY"],
},
}
]
}
@ -27,11 +28,9 @@ response = requests.post(
url=f"{base_url}/key/generate",
json={
"config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
},
headers={
"Authorization": "Bearer sk-hosted-litellm"
}
headers={"Authorization": "Bearer sk-hosted-litellm"},
)
print("\nresponse from generating key", response.text)
@ -45,14 +44,15 @@ print("Creating a job on the proxy")
job_response = requests.post(
url=f"{base_url}/queue/request",
json={
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": f"You are a helpful assistant. What is your name",
},
],
},
headers={
"Authorization": f"Bearer {generated_key}"
}
headers={"Authorization": f"Bearer {generated_key}"},
)
print(job_response.status_code)
print(job_response.text)
@ -68,10 +68,7 @@ while True:
try:
print("\nPolling URL", polling_url)
polling_response = requests.get(
url=polling_url,
headers={
"Authorization": f"Bearer {generated_key}"
}
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
)
print("\nResponse from polling url", polling_response.text)
polling_response = polling_response.json()

View file

@ -8,9 +8,12 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status
def print_verbose(print_statement):
if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
"""
@ -57,11 +60,14 @@ class ProxyLogging:
+ litellm.failure_callback
)
)
litellm.utils.set_callbacks(
callback_list=callback_list
)
litellm.utils.set_callbacks(callback_list=callback_list)
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal["completion", "embeddings"],
):
"""
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@ -71,12 +77,19 @@ class ProxyLogging:
"""
try:
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
callback.__class__
):
response = await callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None:
data = response
print_verbose(f'final data being sent to {call_type} call: {data}')
print_verbose(f"final data being sent to {call_type} call: {data}")
return data
except Exception as e:
raise e
@ -96,7 +109,9 @@ class ProxyLogging:
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
async def post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
@ -108,7 +123,10 @@ class ProxyLogging:
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
await callback.async_post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
)
except Exception as e:
raise e
return
@ -121,9 +139,12 @@ def on_backoff(details):
# The 'tries' key in the details dictionary contains the number of completed tries
print_verbose(f"Backing off... this was attempt #{details['tries']}")
class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
@ -136,15 +157,16 @@ class PrismaClient:
os.chdir(dname)
try:
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
subprocess.run(["prisma", "generate"])
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
self.db = Client() # Client to connect to Prisma db
def hash_token(self, token: str):
# Hash the string using SHA-256
@ -167,7 +189,12 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
async def get_data(
self,
token: Optional[str] = None,
expires: Optional[Any] = None,
user_id: Optional[str] = None,
):
try:
response = None
if token is not None:
@ -176,9 +203,7 @@ class PrismaClient:
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
where={"token": hashed_token}
)
if response:
# Token exists, now check expiration.
@ -188,11 +213,17 @@ class PrismaClient:
return response
else:
# Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
return response
else:
# Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid user key",
)
elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
@ -201,7 +232,9 @@ class PrismaClient:
)
return response
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@ -224,26 +257,26 @@ class PrismaClient:
max_budget = db_data.pop("max_budget", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
"token": hashed_token,
},
data={
"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={
'user_id': data['user_id']
},
where={"user_id": data["user_id"]},
data={
"create": {"user_id": data['user_id'], "max_budget": max_budget},
"update": {} # don't do anything if it already exists
}
"create": {"user_id": data["user_id"], "max_budget": max_budget},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@ -254,7 +287,12 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
async def update_data(
self,
token: Optional[str] = None,
data: dict = {},
user_id: Optional[str] = None,
):
"""
Update existing data
"""
@ -267,10 +305,8 @@ class PrismaClient:
token = self.hash_token(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token # type: ignore
},
data={**db_data} # type: ignore
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": db_data}
@ -279,18 +315,17 @@ class PrismaClient:
If data['spend'] + data['user'], update the user table with spend info as well
"""
update_user_row = await self.db.litellm_usertable.update(
where={
'user_id': user_id # type: ignore
},
data={**db_data} # type: ignore
where={"user_id": user_id}, # type: ignore
data={**db_data}, # type: ignore
)
return {"user_id": user_id, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
@ -310,7 +345,9 @@ class PrismaClient:
)
return {"deleted_keys": tokens}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@ -325,7 +362,9 @@ class PrismaClient:
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@ -340,9 +379,12 @@ class PrismaClient:
try:
await self.db.disconnect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
try:
@ -357,12 +399,14 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
# If config_file_path is provided, use it to determine the module spec and load the module
if config_file_path is not None:
directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, *module_name.split('.'))
module_file_path += '.py'
module_file_path = os.path.join(directory, *module_name.split("."))
module_file_path += ".py"
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
if spec is None:
raise ImportError(f"Could not find a module specification for {module_file_path}")
raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
else:
@ -379,6 +423,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
except Exception as e:
raise e
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
"""
@ -390,5 +435,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
if response is None: # Cache miss
user_row = await db.get_data(user_id=user_id)
cache_value = user_row.model_dump_json()
cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
cache.set_cache(
key=cache_key, value=cache_value, ttl=600
) # store for 10 minutes
return
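For reference, pre_call_hook above only dispatches to callbacks that define async_pre_call_hook, passing the request body, the user-key cache, and the call type. A minimal sketch of a handler that hook would pick up; the method signature mirrors the call site above, while the import path for UserAPIKeyAuth and the rejection policy are assumptions.

from typing import Literal

from fastapi import HTTPException, status

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth  # assumed location of the auth model


class RequestSizeGuard(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal["completion", "embeddings"],
    ):
        # Illustrative policy: reject oversized completion payloads, otherwise pass data through.
        if call_type == "completion" and len(str(data.get("messages", ""))) > 100_000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="prompt too large"
            )
        return data  # the returned dict replaces the request body sent to the model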

File diff suppressed because it is too large

View file

@ -8,18 +8,18 @@
import dotenv, os, requests
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
class LeastBusyLoggingHandler(CustomLogger):
def __init__(self, router_cache: DualCache):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
def log_pre_api_call(self, model, messages, kwargs):
"""
Log when a model is being used.
@ -27,13 +27,16 @@ class LeastBusyLoggingHandler(CustomLogger):
Caching based on model group.
"""
try:
if kwargs['litellm_params'].get('metadata') is None:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if deployment is None or model_group is None or id is None:
return
@ -42,53 +45,75 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_api_key = f"{model_group}_request_count"
# update cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = (
request_count_dict.get(deployment, 0) + 1
)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
def get_available_deployments(self, model_group: str):
request_count_api_key = f"{model_group}_request_count"
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
# map deployment to id
return_dict = {}
for key, value in request_count_dict.items():

View file

@ -2,6 +2,7 @@
import pytest, sys, os
import importlib
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -14,17 +15,23 @@ def setup_and_teardown():
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm
importlib.reload(litellm)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
custom_logger_tests = [
item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)

View file

@ -4,6 +4,7 @@
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -11,15 +12,17 @@ import litellm
from litellm import Router
import concurrent
from dotenv import load_dotenv
load_dotenv()
model_list = [{ # list of model deployments
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
@ -31,24 +34,30 @@ model_list = [{ # list of model deployments
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
"rpm": 9000,
},
]
kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}
kwargs = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
def test_multiple_deployments_sync():
import concurrent, time
litellm.set_verbose=False
litellm.set_verbose = False
results = []
router = Router(model_list=model_list,
router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1) # type: ignore
num_retries=1,
) # type: ignore
try:
for _ in range(3):
response = router.completion(**kwargs)
@ -59,6 +68,7 @@ def test_multiple_deployments_sync():
print(f"FAILED TEST!")
pytest.fail(f"An error occurred - {traceback.format_exc()}")
# test_multiple_deployments_sync()
@ -67,13 +77,15 @@ def test_multiple_deployments_parallel():
results = []
futures = {}
start_time = time.time()
router = Router(model_list=model_list,
router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1) # type: ignore
num_retries=1,
) # type: ignore
# Assuming you have an executor instance defined somewhere in your code
with concurrent.futures.ThreadPoolExecutor() as executor:
for _ in range(5):
@ -82,7 +94,11 @@ def test_multiple_deployments_parallel():
# Retrieve the results from the futures
while futures:
done, not_done = concurrent.futures.wait(futures.values(), timeout=10, return_when=concurrent.futures.FIRST_COMPLETED)
done, not_done = concurrent.futures.wait(
futures.values(),
timeout=10,
return_when=concurrent.futures.FIRST_COMPLETED,
)
for future in done:
try:
result = future.result()
@ -98,8 +114,10 @@ def test_multiple_deployments_parallel():
print(results)
print(f"ELAPSED TIME: {end_time - start_time}")
# Assuming litellm, router, and executor are defined somewhere in your code
# test_multiple_deployments_parallel()
def test_cooldown_same_model_name():
# users could have the same model with different api_base
@ -118,7 +136,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": "BAD_API_BASE",
"tpm": 90
"tpm": 90,
},
},
{
@ -128,7 +146,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"tpm": 0.000001
"tpm": 0.000001,
},
},
]
@ -140,17 +158,12 @@ def test_cooldown_same_model_name():
redis_port=int(os.getenv("REDIS_PORT")),
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=3
num_retries=3,
) # type: ignore
response = router.completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "hello this request will pass"
}
]
messages=[{"role": "user", "content": "hello this request will pass"}],
)
print(router.model_list)
model_ids = []
@ -159,10 +172,13 @@ def test_cooldown_same_model_name():
print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names
assert (
model_ids[0] != model_ids[1]
) # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response)
except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {e}")
test_cooldown_same_model_name()
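For orientation, the Router usage these tests exercise reduces to the pattern below; the API key is a placeholder, and the Redis settings from the tests are omitted because they are only needed for cross-process routing strategies.

import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
            "tpm": 1000000,
            "rpm": 9000,
        }
    ],
    routing_strategy="simple-shuffle",
    num_retries=1,
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello this request will pass"}],
)
print(response.choices[0].message.content)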

View file

@ -9,11 +9,12 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
## case 1: set_function_to_prompt not set
def test_function_call_non_openai_model():
try:
model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}]
messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [
{
"name": "get_current_weather",
@ -23,31 +24,32 @@ def test_function_call_non_openai_model():
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"required": ["location"]
}
}
]
response = litellm.completion(model=model, messages=messages, functions=functions)
pytest.fail(f'An error occurred')
response = litellm.completion(
model=model, messages=messages, functions=functions
)
pytest.fail(f"An error occurred")
except Exception as e:
print(e)
pass
test_function_call_non_openai_model()
## case 2: add_function_to_prompt set
def test_function_call_non_openai_model_litellm_mod_set():
litellm.add_function_to_prompt = True
try:
model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}]
messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [
{
"name": "get_current_weather",
@ -57,20 +59,20 @@ def test_function_call_non_openai_model_litellm_mod_set():
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"required": ["location"]
}
}
]
response = litellm.completion(model=model, messages=messages, functions=functions)
print(f'response: {response}')
response = litellm.completion(
model=model, messages=messages, functions=functions
)
print(f"response: {response}")
except Exception as e:
pytest.fail(f'An error occurred {e}')
pytest.fail(f"An error occurred {e}")
# test_function_call_non_openai_model_litellm_mod_set()

View file

@ -1,4 +1,3 @@
import sys, os
import traceback
from dotenv import load_dotenv
@ -27,11 +26,11 @@ def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + '/vertex_key.json'
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, 'r') as file:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
@ -55,13 +54,13 @@ def load_vertex_ai_credentials():
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
@pytest.mark.asyncio
async def get_response():
@ -89,43 +88,80 @@ def test_vertex_ai():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
litellm.set_verbose=False
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
litellm.set_verbose = False
litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
print("making request", model)
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)
response = completion(
model=model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
)
print("\nModel Response", response)
print(response)
assert type(response.choices[0].message.content) == str
assert len(response.choices[0].message.content) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_vertex_ai()
def test_vertex_ai_stream():
load_vertex_ai_credentials()
litellm.set_verbose=False
litellm.set_verbose = False
litellm.vertex_project = "reliablekeys"
import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
print("making request", model)
response = completion(model=model, messages=[{"role": "user", "content": "write 10 line code code for saying hi"}], stream=True)
response = completion(
model=model,
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
],
stream=True,
)
completed_str = ""
for chunk in response:
print(chunk)
@ -137,47 +173,86 @@ def test_vertex_ai_stream():
assert len(completed_str) > 4
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream()
@pytest.mark.asyncio
async def test_async_vertexai_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
print(f'model being tested in async call: {model}')
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
print(f"model being tested in async call: {model}")
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
response = await acompletion(
model=model, messages=messages, temperature=0.7, timeout=5
)
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_response())
@pytest.mark.asyncio
async def test_async_vertexai_streaming_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
response = await acompletion(
model="gemini-pro",
messages=messages,
temperature=0.7,
timeout=5,
stream=True,
)
print(f"response: {response}")
complete_response = ""
async for chunk in response:
@ -191,38 +266,40 @@ async def test_async_vertexai_streaming_response():
print(e)
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_streaming_response())
def test_gemini_pro_vision():
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
litellm.num_retries=0
litellm.num_retries = 0
resp = litellm.completion(
model = "vertex_ai/gemini-pro-vision",
model="vertex_ai/gemini-pro-vision",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
}
}
]
},
},
],
}
],
)
print(resp)
except Exception as e:
import traceback
traceback.print_exc()
raise e
# test_gemini_pro_vision()

View file

@ -11,8 +11,10 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion, acreate
litellm.num_retries = 3
def test_sync_response():
litellm.set_verbose = False
user_message = "Hello, how are you?"
@ -24,28 +26,42 @@ def test_sync_response():
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# test_sync_response()
def test_sync_response_anyscale():
litellm.set_verbose = False
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
response = completion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# test_sync_response_anyscale()
def test_async_response_openai():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
response = await acompletion(
model="gpt-3.5-turbo", messages=messages, timeout=5
)
print(f"response: {response}")
print(f"response ms: {response._response_ms}")
except litellm.Timeout as e:
@ -56,16 +72,25 @@ def test_async_response_openai():
asyncio.run(test_get_response())
# test_async_response_openai()
def test_async_response_azure():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "What do you know?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY"))
response = await acompletion(
model="azure/gpt-turbo",
messages=messages,
base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
)
print(f"response: {response}")
except litellm.Timeout as e:
pass
@ -74,17 +99,24 @@ def test_async_response_azure():
asyncio.run(test_get_response())
# test_async_response_azure()
def test_async_anyscale_response():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
# response = await response
print(f"response: {response}")
except litellm.Timeout as e:
@ -94,16 +126,21 @@ def test_async_anyscale_response():
asyncio.run(test_get_response())
# test_async_anyscale_response()
def test_get_response_streaming():
import asyncio
async def test_async_call():
user_message = "write a short poem in one sentence"
messages = [{"content": user_message, "role": "user"}]
try:
litellm.set_verbose = True
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
response = await acompletion(
model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5
)
print(type(response))
import inspect
@ -121,24 +158,34 @@ def test_get_response_streaming():
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
print(f'output: {output}')
print(f"output: {output}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_async_call())
# test_get_response_streaming()
def test_get_response_non_openai_streaming():
import asyncio
litellm.set_verbose = True
litellm.num_retries = 0
async def test_async_call():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5)
response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
stream=True,
timeout=5,
)
print(type(response))
import inspect
@ -163,6 +210,8 @@ def test_get_response_non_openai_streaming():
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
# test_get_response_non_openai_streaming()
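Most of the chunk handling in these streaming tests is elided by the hunks; the consumption pattern they rely on is an async iteration over delta content, roughly as follows (model and prompt are placeholders):

import asyncio

import litellm


async def stream_demo():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "write a short poem"}],
        stream=True,
    )
    output = ""
    async for chunk in response:
        if len(chunk.choices) == 0:
            continue
        output += chunk.choices[0].delta.content or ""  # delta content can be None on the final chunk
    print(output)


asyncio.run(stream_demo())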

View file

@ -3,14 +3,15 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
import openai, litellm, uuid
from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI(
api_key=os.getenv("AZURE_API_KEY"),
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
api_version=os.getenv("AZURE_API_VERSION")
api_version=os.getenv("AZURE_API_VERSION"),
)
model_list = [
@ -20,59 +21,84 @@ model_list = [
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
}
"api_version": os.getenv("AZURE_API_VERSION"),
},
}
]
router = litellm.Router(model_list=model_list)
async def _openai_completion():
try:
start_time = time.time()
response = await client.chat.completions.create(
model="chatgpt-v-2",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True
stream=True,
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print("OpenAI Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time)
print(
"OpenAI Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time,
)
return time_to_first_token
except Exception as e:
print(e)
return None
async def _router_completion():
try:
start_time = time.time()
response = await router.acompletion(
model="azure-test",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True
stream=True,
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print("Router Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time - first_token_ts)
print(
"Router Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time - first_token_ts,
)
return time_to_first_token
except Exception as e:
print(e)
return None
async def test_azure_completion_streaming():
"""
Test azure streaming call - measure on time to first (non-null) token.
@ -85,7 +111,7 @@ async def test_azure_completion_streaming():
total_time = 0
for item in successful_completions:
total_time += item
avg_openai_time = total_time/3
avg_openai_time = total_time / 3
## ROUTER AVG. TIME
tasks = [_router_completion() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks)
@ -93,9 +119,10 @@ async def test_azure_completion_streaming():
total_time = 0
for item in successful_completions:
total_time += item
avg_router_time = total_time/3
avg_router_time = total_time / 3
## COMPARE
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
assert avg_router_time < avg_openai_time + 0.5
# asyncio.run(test_azure_completion_streaming())

View file

@ -5,6 +5,7 @@
import sys, os
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -18,6 +19,7 @@ user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
model_val = None
def test_completion_with_no_model():
# test on empty
with pytest.raises(ValueError):
@ -32,6 +34,7 @@ def test_completion_with_empty_model():
print(f"error occurred: {e}")
pass
# def test_completion_catch_nlp_exception():
# TEMP commented out NLP cloud API is unstable
# try:
@ -64,6 +67,7 @@ def test_completion_with_empty_model():
# test_completion_catch_nlp_exception()
def test_completion_invalid_param_cohere():
try:
response = completion(model="command-nightly", messages=messages, top_p=1)
@ -72,14 +76,18 @@ def test_completion_invalid_param_cohere():
if "Unsupported parameters passed: top_p" in str(e):
pass
else:
pytest.fail(f'An error occurred {e}')
pytest.fail(f"An error occurred {e}")
# test_completion_invalid_param_cohere()
def test_completion_function_call_cohere():
try:
response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
pytest.fail(f'An error occurred {e}')
response = completion(
model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]
)
pytest.fail(f"An error occurred {e}")
except Exception as e:
print(e)
pass
@ -87,10 +95,14 @@ def test_completion_function_call_cohere():
# test_completion_function_call_cohere()
def test_completion_function_call_openai():
try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(model="gpt-3.5-turbo", messages=messages, functions=[
response = completion(
model="gpt-3.5-turbo",
messages=messages,
functions=[
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
@ -99,23 +111,26 @@ def test_completion_function_call_openai():
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
"required": ["location"]
}
}
])
],
)
print(f"response: {response}")
except:
pass
# test_completion_function_call_openai()
def test_completion_with_no_provider():
# test on empty
try:
@ -125,6 +140,7 @@ def test_completion_with_no_provider():
print(f"error occurred: {e}")
pass
# test_completion_with_no_provider()
# # bad key
# temp_key = os.environ.get("OPENAI_API_KEY")

View file

@ -4,15 +4,24 @@
import sys, os
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from openai import APITimeoutError as Timeout
import litellm
litellm.num_retries = 0
from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses
from litellm import (
batch_completion,
batch_completion_models,
completion,
batch_completion_models_all_responses,
)
# litellm.set_verbose=True
def test_batch_completions():
messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)]
model = "j2-mid"
@ -23,43 +32,50 @@ def test_batch_completions():
messages=messages,
max_tokens=10,
temperature=0.2,
request_timeout=1
request_timeout=1,
)
print(result)
print(len(result))
assert(len(result)==3)
assert len(result) == 3
except Timeout as e:
print(f"IN TIMEOUT")
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
test_batch_completions()
def test_batch_completions_models():
try:
result = batch_completion_models(
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
messages=[{"role": "user", "content": "Hey, how's it going"}]
messages=[{"role": "user", "content": "Hey, how's it going"}],
)
print(result)
except Timeout as e:
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
# test_batch_completions_models()
def test_batch_completion_models_all_responses():
try:
responses = batch_completion_models_all_responses(
models=["j2-light", "claude-instant-1.2"],
messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10
max_tokens=10,
)
print(responses)
assert(len(responses) == 2)
assert len(responses) == 2
except Timeout as e:
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
# test_batch_completion_models_all_responses()
# test_batch_completion_models_all_responses()

View file

@ -14,6 +14,7 @@ import litellm
from litellm import embedding, completion
from litellm.caching import Cache
import random
# litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}]
@ -22,14 +23,18 @@ messages = [{"role": "user", "content": "who is ishaan Github? "}]
import random
import string
def generate_random_word(length=4):
letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length))
return "".join(random.choice(letters) for _ in range(length))
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
try:
litellm.set_verbose=True
litellm.set_verbose = True
litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@ -38,7 +43,10 @@ def test_caching_v2(): # test in memory cache
litellm.cache = None # disable cache
litellm.success_callback = []
litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_v2()
def test_caching_with_models_v2():
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
messages = [
{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
]
litellm.cache = Cache()
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@ -63,34 +73,51 @@ def test_caching_with_models_v2():
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']:
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
# test_caching_with_models_v2()
embedding_large_text = """
embedding_large_text = (
"""
small text
""" * 5
"""
* 5
)
# # test_caching_with_models()
def test_embedding_caching():
import time
litellm.cache = Cache()
text_to_embed = [embedding_large_text]
start_time = time.time()
embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
embedding1 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds")
time.sleep(1)
start_time = time.time()
embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
embedding2 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time()
print(f"embedding2: {embedding2}")
print(f"Embedding 2 response time: {end_time - start_time} seconds")
@ -99,28 +126,29 @@ def test_embedding_caching():
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
# test_embedding_caching()
def test_embedding_caching_azure():
print("Testing azure embedding caching")
import time
litellm.cache = Cache()
text_to_embed = [embedding_large_text]
api_key = os.environ['AZURE_API_KEY']
api_base = os.environ['AZURE_API_BASE']
api_version = os.environ['AZURE_API_VERSION']
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ["AZURE_API_VERSION"]
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
start_time = time.time()
print("AZURE CONFIGS")
@@ -133,7 +161,7 @@ def test_embedding_caching_azure():
api_key=api_key,
api_base=api_base,
api_version=api_version,
caching=True
caching=True,
)
end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds")
@@ -146,7 +174,7 @@ def test_embedding_caching_azure():
api_key=api_key,
api_base=api_base,
api_version=api_version,
caching=True
caching=True,
)
end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds")
@@ -155,14 +183,15 @@ def test_embedding_caching_azure():
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
os.environ['AZURE_API_VERSION'] = api_version
os.environ['AZURE_API_BASE'] = api_base
os.environ['AZURE_API_KEY'] = api_key
os.environ["AZURE_API_VERSION"] = api_version
os.environ["AZURE_API_BASE"] = api_base
os.environ["AZURE_API_KEY"] = api_key
# test_embedding_caching_azure()
@@ -170,13 +199,28 @@ def test_embedding_caching_azure():
def test_redis_cache_completion():
litellm.set_verbose = False
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
response1 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response2 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response3 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
)
response4 = completion(model="command-nightly", messages=messages, caching=True)
print("\nresponse 1", response1)
@@ -192,49 +236,88 @@ def test_redis_cache_completion():
1 & 3 should be different, since input params are diff
1 & 4 should be diff, since models are diff
"""
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']:
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like seed, max_tokens are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:")
if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']:
pytest.fail(
f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:"
)
if (
response1["choices"][0]["message"]["content"]
== response4["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response
print(f"response1: {response1}")
print(f"response4: {response4}")
pytest.fail(f"Error occurred:")
# test_redis_cache_completion()
def test_redis_cache_completion_stream():
try:
litellm.success_callback = []
litellm._async_success_callback = []
litellm.callbacks = []
litellm.set_verbose = True
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_1_content = ""
for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
time.sleep(0.5)
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_2_content = ""
for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = []
litellm.cache = None
litellm.success_callback = []
@@ -247,99 +330,171 @@ def test_redis_cache_completion_stream():
1 & 2 should be exactly the same
"""
# test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream():
import asyncio
try:
litellm.set_verbose = True
random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock():
import asyncio
try:
litellm.set_verbose = True
random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys
def custom_get_cache_key(*args, **kwargs):
# return key to use for your cache:
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", ""))
key = (
kwargs.get("model", "")
+ str(kwargs.get("messages", ""))
+ str(kwargs.get("temperature", ""))
+ str(kwargs.get("logit_bias", ""))
)
return key
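# A minimal usage sketch for the helper above (illustrative, hypothetical values):
# the key is simply the concatenation of model, messages, temperature and
# logit_bias, so identical inputs always map to the same cache entry.
example_key = custom_get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a one line story"}],
    temperature=1,
    logit_bias={},
)
# -> "gpt-3.5-turbo[{'role': 'user', 'content': 'write a one line story'}]1{}"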
def test_custom_redis_cache_with_key():
messages = [{"role": "user", "content": "write a one line story"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.cache.get_cache_key = custom_get_cache_key
local_cache = {}
@@ -356,53 +511,71 @@ def test_custom_redis_cache_with_key():
# patch this redis cache get and set call
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response3 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=False,
num_retries=3,
)
print(f"response1: {response1}")
print(f"response2: {response2}")
print(f"response3: {response3}")
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
pytest.fail(f"Error occurred:")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# test_custom_redis_cache_with_key()
def test_cache_override():
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
# in this case it should not return cached responses
litellm.cache = Cache()
print("Testing cache override")
litellm.set_verbose=True
litellm.set_verbose = True
# test embedding
response1 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
model="text-embedding-ada-002", input=["hello who are you"], caching=False
)
start_time = time.time()
response2 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
model="text-embedding-ada-002", input=["hello who are you"], caching=False
)
end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
assert (
end_time - start_time > 0.1
) # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override()
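# The same per-call override applies to chat completions. A minimal sketch
# (hypothetical prompt; assumes the completion/Cache imports at the top of this
# file): with litellm.cache configured globally, caching=False serves the
# request fresh instead of returning a cached response.
litellm.cache = Cache()
fresh_response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello who are you"}],
    caching=False,  # skip the global cache for this call only
)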
@@ -411,10 +584,10 @@ def test_custom_redis_cache_params():
try:
litellm.cache = Cache(
type="redis",
host=os.environ['REDIS_HOST'],
port=os.environ['REDIS_PORT'],
password=os.environ['REDIS_PASSWORD'],
db = 0,
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
db=0,
ssl=True,
ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key",
@@ -431,58 +604,126 @@ def test_custom_redis_cache_params():
def test_get_cache_key():
from litellm.caching import Cache
try:
print("Testing get_cache_key")
cache_instance = Cache()
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
cache_key = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
)
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
cache_key_2 = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
)
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
assert (
cache_key
== "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
)
assert (
cache_key == cache_key_2
), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
embedding_cache_key = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
**{
"model": "azure/azure-embedding-model",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_key": "",
"api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
}
)
print(embedding_cache_key)
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
assert (
embedding_cache_key
== "model: azure/azure-embedding-modelinput: ['hi who is ishaan']"
), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
# Proxy - embedding cache, test if embedding key, gets model_group and not model
embedding_cache_key_2 = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
'method': 'POST',
'headers':
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
'content-length': '80'},
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
'user': None,
'metadata': {'user_api_key': None,
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
'model_group': 'EMBEDDING_MODEL_GROUP',
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
**{
"model": "azure/azure-embedding-model",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_key": "",
"api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
"proxy_server_request": {
"url": "http://0.0.0.0:8000/embeddings",
"method": "POST",
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"body": {
"model": "azure-embedding-model",
"input": ["hi who is ishaan"],
},
},
"user": None,
"metadata": {
"user_api_key": None,
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"model_group": "EMBEDDING_MODEL_GROUP",
"deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview",
},
"model_info": {
"mode": "embedding",
"base_model": "text-embedding-ada-002",
"id": "20b2b515-f151-4dd5-a74f-2231e2f54e29",
},
"litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e",
"litellm_logging_obj": "<litellm.utils.Logging object at 0x12f1bddb0>",
}
)
print(embedding_cache_key_2)
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
assert (
embedding_cache_key_2
== "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
)
print("passed!")
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred:", e)
test_get_cache_key()
# test_custom_redis_cache_params()
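# A minimal sketch of the invariant asserted in test_get_cache_key above
# (illustrative values): the key is derived only from request parameters such as
# model, messages, temperature and max_tokens, so per-call bookkeeping fields
# like litellm_call_id do not change it.
from litellm.caching import Cache

_cache = Cache()
_key_a = _cache.get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=40,
    litellm_call_id="call-1",
)
_key_b = _cache.get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=40,
    litellm_call_id="call-2",
)
assert _key_a == _key_b  # different call ids, identical cache key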

View file

@@ -18,15 +18,26 @@ from litellm import embedding, completion, Router
from litellm.caching import Cache
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
def test_caching_v2(): # test in memory cache
try:
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
raise Exception()
@@ -34,6 +45,7 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_v2()
@@ -49,26 +61,41 @@ def test_caching_router():
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
}
]
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
router = Router(model_list=model_list,
litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
router = Router(
model_list=model_list,
routing_strategy="simple-shuffle",
set_verbose=False,
num_retries=1) # type: ignore
num_retries=1,
) # type: ignore
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
assert (
response2["choices"][0]["message"]["content"]
== response1["choices"][0]["message"]["content"]
)
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_router()

File diff suppressed because it is too large

View file

@@ -28,6 +28,7 @@ def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}")
pass
# normal call
def test_completion_custom_provider_model_name():
try:
@@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# completion with num retries + impact on exception mapping
def test_completion_with_num_retries():
try:
response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2)
response = completion(
model="j2-ultra",
messages=[{"messages": "vibe", "bad": "message"}],
num_retries=2,
)
pytest.fail(f"Unmapped exception occurred")
except Exception as e:
pass
# test_completion_with_num_retries()
def test_completion_with_0_num_retries():
try:
litellm.set_verbose=False
litellm.set_verbose = False
print("making request")
# Use the completion function
response = completion(
model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}],
max_retries=4
max_retries=4,
)
print(response)
@@ -69,5 +76,6 @@ def test_completion_with_0_num_retries():
print("exception", e)
pass
# Call the test function
test_completion_with_0_num_retries()

View file

@@ -15,77 +15,104 @@ from litellm import completion_with_config
config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"model": {
"claude-instant-1": {
"needs_moderation": True
},
"claude-instant-1": {"needs_moderation": True},
"gpt-3.5-turbo": {
"error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
}
}
}
},
},
}
def test_config_context_window_exceeded():
try:
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_window_exceeded()
def test_config_context_moderation():
try:
messages=[{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
messages = [{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(
model="claude-instant-1", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_moderation()
def test_config_context_default_fallback():
try:
messages=[{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key")
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(
model="claude-instant-1",
messages=messages,
config=config,
api_key="bad-key",
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_default_fallback()
config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"],
"available_models": [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"j2-ultra",
"command-nightly",
"togethercomputer/llama-2-70b-chat",
"chat-bison",
"chat-bison@001",
"claude-2",
],
"adapt_to_prompt_size": True, # type: ignore
"model": {
"claude-instant-1": {
"needs_moderation": True
},
"claude-instant-1": {"needs_moderation": True},
"gpt-3.5-turbo": {
"error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
}
}
}
},
},
}
def test_config_context_adapt_to_prompt():
try:
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
test_config_context_adapt_to_prompt()

View file

@@ -4,6 +4,8 @@ from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
print(f"api_key: {api_key}")

View file

@@ -2,6 +2,7 @@ from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm
class testCustomCallbackProxy(CustomLogger):
def __init__(self):
self.success: bool = False # type: ignore
@@ -24,7 +25,11 @@ class testCustomCallbackProxy(CustomLogger):
print(f"{blue_color_code}Initialized LiteLLM custom logger")
try:
print(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print the methods
for method in methods:
@@ -55,7 +60,10 @@ class testCustomCallbackProxy(CustomLogger):
self.async_success = True
print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs)
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
if (
kwargs.get("model") == "azure-embedding-model"
or kwargs.get("model") == "ada"
):
print("Got an embedding model", kwargs.get("model"))
print("Setting embedding success to True")
self.async_success_embedding = True
@@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("stream") == True:
self.streaming_response_obj = response_obj
self.async_completion_kwargs = kwargs
model = kwargs.get("model", None)
@@ -74,7 +81,9 @@ class testCustomCallbackProxy(CustomLogger):
# Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
metadata = litellm_params.get(
"metadata", {}
) # headers passed to LiteLLM proxy, can be found here
# Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj)
@@ -84,7 +93,6 @@ class testCustomCallbackProxy(CustomLogger):
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print(
f"""
Model: {model},
@@ -98,7 +106,6 @@ class testCustomCallbackProxy(CustomLogger):
)
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure")
self.async_failure = True
@@ -110,4 +117,5 @@ class testCustomCallbackProxy(CustomLogger):
self.async_completion_kwargs_fail = kwargs
my_custom_logger = testCustomCallbackProxy()

View file

@@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List, Union
from litellm import completion, embedding, Cache
import litellm
@@ -25,14 +26,32 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
@@ -42,13 +61,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -63,18 +82,24 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -89,18 +114,29 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -115,18 +151,25 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -141,18 +184,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -165,13 +218,15 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -184,20 +239,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
assert isinstance(
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except:
print(f"Assertion Error: {traceback.format_exc()}")
@@ -213,18 +276,24 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@@ -236,29 +305,26 @@ def test_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync openai"
}])
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}],
)
## test streaming
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
for chunk in response:
continue
## test failure callback
try:
response = litellm.completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
api_key="my-bad-key",
stream=True)
stream=True,
)
for chunk in response:
continue
except:
@@ -270,37 +336,36 @@ def test_chat_openai_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_openai_stream()
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
## test streaming
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
continue
## test failure callback
try:
response = await litellm.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
api_key="my-bad-key",
stream=True)
stream=True,
)
async for chunk in response:
continue
except:
@@ -312,36 +377,35 @@ async def test_async_chat_openai_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync
def test_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}])
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
)
# test streaming
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}],
stream=True)
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
stream=True,
)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync azure"
}],
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
api_key="my-bad-key",
stream=True)
stream=True,
)
for chunk in response:
continue
except:
@@ -353,37 +417,36 @@ def test_chat_azure_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_azure_stream()
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}])
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
)
## test streaming
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}],
stream=True)
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
stream=True,
)
async for chunk in response:
continue
## test failure callback
try:
response = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async azure"
}],
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
api_key="my-bad-key",
stream=True)
stream=True,
)
async for chunk in response:
continue
except:
@@ -395,36 +458,35 @@ async def test_async_chat_azure_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_azure_stream())
## Test Bedrock + sync
def test_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}])
response = litellm.completion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
)
# test streaming
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}],
stream=True)
response = litellm.completion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
stream=True,
)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm sync bedrock"
}],
response = litellm.completion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
aws_region_name="my-bad-region",
stream=True)
stream=True,
)
for chunk in response:
continue
except:
@@ -436,39 +498,38 @@ def test_chat_bedrock_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_bedrock_stream()
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}])
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
)
# test streaming
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}],
stream=True)
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
stream=True,
)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm async bedrock"
}],
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
aws_region_name="my-bad-key",
stream=True)
stream=True,
)
async for chunk in response:
continue
except:
@@ -480,8 +541,10 @@ async def test_async_chat_bedrock_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_bedrock_stream())
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
@@ -490,8 +553,9 @@ async def test_async_embedding_openai():
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
response = await litellm.aembedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
@@ -500,9 +564,11 @@ async def test_async_embedding_openai():
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="text-embedding-ada-002",
response = await litellm.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm"],
api_key="my-bad-key")
api_key="my-bad-key",
)
except:
pass
await asyncio.sleep(1)
@@ -513,8 +579,10 @@ async def test_async_embedding_openai():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_openai())
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
@@ -522,8 +590,9 @@ async def test_async_embedding_azure():
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
response = await litellm.aembedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
@@ -532,9 +601,11 @@ async def test_async_embedding_azure():
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="azure/azure-embedding-model",
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key")
api_key="my-bad-key",
)
except:
pass
await asyncio.sleep(1)
@@ -545,8 +616,10 @@ async def test_async_embedding_azure():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
@@ -555,8 +628,11 @@ async def test_async_embedding_bedrock():
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
litellm.set_verbose = True
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
response = await litellm.aembedding(
model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="os.environ/AWS_REGION_NAME_2",
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
@@ -565,9 +641,11 @@ async def test_async_embedding_bedrock():
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
response = await litellm.aembedding(
model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="my-bad-region")
aws_region_name="my-bad-region",
)
except:
pass
await asyncio.sleep(1)
@ -578,54 +656,72 @@ async def test_async_embedding_bedrock():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_bedrock())
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching")
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.aembedding(model="azure/azure-embedding-model",
response1 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
caching=True,
)
await asyncio.sleep(1) # set cache is async for aembedding()
response2 = await litellm.aembedding(model="azure/azure-embedding-model",
response2 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True)
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(customHandler_caching.states)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
# asyncio.run(
# test_async_embedding_azure_caching()
# )
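For readers skimming this reformat-only diff, the pattern every test above depends on is: subclass CustomLogger, record which lifecycle hooks fire, and register an instance on litellm.callbacks. A minimal sketch follows; the pre-call hook signature matches the one visible in these hunks, while the success/failure hook names come from the custom-callback docs linked in the test file, so treat them as assumptions if your litellm version differs.

import litellm
from litellm.integrations.custom_logger import CustomLogger


class MinimalHandler(CustomLogger):
    """Records which lifecycle hooks fired, like CompletionCustomHandler does."""

    def __init__(self):
        self.states = []

    def log_pre_api_call(self, model, messages, kwargs):
        # runs before the provider request is sent
        self.states.append("sync_pre_api_call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # runs after a successful completion/embedding call
        self.states.append("sync_success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("sync_failure")


litellm.callbacks = [MinimalHandler()]  # same registration used throughout these tests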
View file
@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List
from litellm import Router, Cache
import litellm
@ -29,39 +30,61 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a router call
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
print(f'received kwargs in pre-input: {kwargs}')
print(f"received kwargs in pre-input: {kwargs}")
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
@ -77,26 +100,36 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
@ -112,18 +145,29 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -138,18 +182,25 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except:
print(f"Assertion Error: {traceback.format_exc()}")
@ -165,18 +216,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -200,20 +261,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
assert isinstance(
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
@ -221,8 +290,12 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
@ -239,22 +312,30 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
@ -271,37 +352,39 @@ async def test_async_chat_azure():
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
await asyncio.sleep(2)
assert len(customHandler_completion_azure_router.errors) == 0
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
assert (
len(customHandler_completion_azure_router.states) == 3
) # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
response = await router2.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
print(f"async azure router chunk: {chunk}")
continue
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
assert len(customHandler_streaming_azure_router.errors) == 0
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
assert (
len(customHandler_streaming_azure_router.states) >= 4
) # pre, post, stream (multiple times), success
# failure
model_list = [
{
@ -310,20 +393,19 @@ async def test_async_chat_azure():
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
response = await router3.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
print(f"response in router3 acompletion: {response}")
except:
pass
@ -335,6 +417,8 @@ async def test_async_chat_azure():
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure())
## EMBEDDING
@pytest.mark.asyncio
@ -350,15 +434,16 @@ async def test_async_embedding_azure():
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
response = await router.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
)
await asyncio.sleep(2)
assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success
@ -370,17 +455,18 @@ async def test_async_embedding_azure():
"model": "azure/azure-embedding-model",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
response = await router3.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
)
print(f"response in router3 aembedding: {response}")
except:
pass
@ -392,6 +478,8 @@ async def test_async_embedding_azure():
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks
## COMPLETION
@ -408,10 +496,10 @@ async def test_async_chat_azure_with_fallbacks():
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo-16k",
@ -419,31 +507,40 @@ async def test_async_chat_azure_with_fallbacks():
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
"rpm": 1800,
},
]
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
await asyncio.sleep(2)
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
assert len(customHandler_fallbacks.errors) == 0
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
assert (
len(customHandler_fallbacks.states) == 6
) # pre, post, failure, pre, post, success
litellm.callbacks = []
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure_with_fallbacks())
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
model_list = [
@ -453,10 +550,10 @@ async def test_async_completion_azure_caching():
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo-16k",
@ -464,25 +561,25 @@ async def test_async_completion_azure_caching():
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
response1 = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
response2 = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
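Stripped of the Router plumbing, the caching tests in this file and the previous one reduce to three steps: point litellm.cache at Redis, send the same request twice with caching=True, and expect the registered callback to record pre, post, success, success. A condensed sketch of that setup, reusing the env var names from the tests (everything else is illustrative):

import asyncio
import os
import time

import litellm
from litellm import Cache

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)


async def cached_twice():
    unique = time.time()  # the tests append this so each run starts with a cold cache
    kwargs = dict(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique}"}],
        caching=True,
    )
    first = await litellm.acompletion(**kwargs)  # hits the provider, writes the cache
    await asyncio.sleep(1)  # mirrors the tests: callbacks and cache writes run async
    second = await litellm.acompletion(**kwargs)  # expected to be served from the cache
    return first, second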
View file
@ -1,12 +1,14 @@
import sys
import os
import io, asyncio
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion
import litellm
litellm.num_retries = 3
import time, random
@ -29,11 +31,14 @@ def pre_request():
import re
def verify_log_file(log_file_path):
with open(log_file_path, 'r') as log_file:
def verify_log_file(log_file_path):
with open(log_file_path, "r") as log_file:
log_content = log_file.read()
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)
print(
f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content
)
# Define the pattern to search for in the log file
pattern = r"Response from DynamoDB:{.*?}"
@ -50,7 +55,11 @@ def verify_log_file(log_file_path):
print(f"Total occurrences of specified response: {len(matches)}")
# Count the occurrences of successful responses (status code 200 or 201)
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)
success_count = sum(
1
for match in matches
if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
)
# Print the count of successful responses
print(f"Count of successful responses from DynamoDB: {success_count}")
@ -69,41 +78,41 @@ def test_dynamo_logging():
litellm.set_verbose = True
original_stdout, log_file, file_name = pre_request()
print("Testing async dynamoDB logging")
async def _test():
return await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=100,
temperature=0.7,
user = "ishaan-2"
user="ishaan-2",
)
response = asyncio.run(_test())
print(f"response: {response}")
# streaming + async
async def _test2():
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=10,
temperature=0.7,
user = "ishaan-2",
stream=True
user="ishaan-2",
stream=True,
)
async for chunk in response:
pass
asyncio.run(_test2())
# aembedding()
async def _test3():
return await litellm.aembedding(
model="text-embedding-ada-002",
input = ["hi"],
user = "ishaan-2"
model="text-embedding-ada-002", input=["hi"], user="ishaan-2"
)
response = asyncio.run(_test3())
time.sleep(1)
except Exception as e:
@ -117,4 +126,5 @@ def test_dynamo_logging():
verify_log_file(file_name)
print("Passed! Testing async dynamoDB logging")
# test_dynamo_logging_async()
View file
@ -14,17 +14,18 @@ from litellm import embedding, completion
litellm.set_verbose = False
def test_openai_embedding():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
metadata = {"anything": "good day"}
metadata={"anything": "good day"},
)
litellm_response = dict(response)
litellm_response_keys = set(litellm_response.keys())
litellm_response_keys.discard('_response_ms')
litellm_response_keys.discard("_response_ms")
print(litellm_response_keys)
print("LiteLLM Response\n")
@ -32,21 +33,30 @@ def test_openai_embedding():
# same request with OpenAI 1.0+
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'])
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
response = client.embeddings.create(
model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"]
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
)
response = dict(response)
openai_response_keys = set(response.keys())
print(openai_response_keys)
assert litellm_response_keys == openai_response_keys # ENSURE the keys in the litellm response are exactly what the openai package returns
assert len(litellm_response["data"]) == 2 # expect two embedding responses from litellm_response since input had two
assert (
litellm_response_keys == openai_response_keys
) # ENSURE the keys in the litellm response are exactly what the openai package returns
assert (
len(litellm_response["data"]) == 2
) # expect two embedding responses from litellm_response since input had two
print(openai_response_keys)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_embedding()
def test_openai_azure_embedding_simple():
try:
response = embedding(
@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple():
)
print(response)
response_keys = set(dict(response).keys())
response_keys.discard('_response_ms')
assert set(["usage", "model", "object", "data"]) == set(response_keys) #assert litellm response has expected keys from OpenAI embedding response
response_keys.discard("_response_ms")
assert set(["usage", "model", "object", "data"]) == set(
response_keys
) # assert litellm response has expected keys from OpenAI embedding response
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding_simple()
@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts():
response = embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
timeout=0.00001
timeout=0.00001,
)
print(response)
except openai.APITimeoutError:
print("Good job got timeout error!")
pass
except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_azure_embedding_timeouts()
def test_openai_embedding_timeouts():
try:
response = embedding(
model="text-embedding-ada-002",
input=["good morning from litellm"],
timeout=0.00001
timeout=0.00001,
)
print(response)
except openai.APITimeoutError:
print("Good job got OpenAI timeout error!")
pass
except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_embedding_timeouts()
def test_openai_azure_embedding():
try:
api_key = os.environ['AZURE_API_KEY']
api_base = os.environ['AZURE_API_BASE']
api_version = os.environ['AZURE_API_VERSION']
api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ["AZURE_API_VERSION"]
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
response = embedding(
model="azure/azure-embedding-model",
@ -114,33 +136,37 @@ def test_openai_azure_embedding():
)
print(response)
os.environ['AZURE_API_VERSION'] = api_version
os.environ['AZURE_API_BASE'] = api_base
os.environ['AZURE_API_KEY'] = api_key
os.environ["AZURE_API_VERSION"] = api_version
os.environ["AZURE_API_BASE"] = api_base
os.environ["AZURE_API_KEY"] = api_key
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding()
# test_openai_embedding()
def test_cohere_embedding():
try:
# litellm.set_verbose=True
response = embedding(
model="embed-english-v2.0", input=["good morning from litellm", "this is another item"]
model="embed-english-v2.0",
input=["good morning from litellm", "this is another item"],
)
print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding()
def test_cohere_embedding3():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="embed-english-v3.0",
input=["good morning from litellm", "this is another item"],
@ -149,97 +175,135 @@ def test_cohere_embedding3():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding3()
def test_bedrock_embedding_titan():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
"lets test a second string for good measure"]
model="amazon.titan-embed-text-v1",
input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
)
print(f"response:", response)
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
assert isinstance(
response["data"][0]["embedding"], list
), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere():
try:
litellm.set_verbose=False
litellm.set_verbose = False
response = embedding(
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
aws_region_name="os.environ/AWS_REGION_NAME_2"
model="cohere.embed-multilingual-v3",
input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
aws_region_name="os.environ/AWS_REGION_NAME_2",
)
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
assert isinstance(
response["data"][0]["embedding"], list
), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
# print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_cohere()
# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
try:
# huggingface/microsoft/codebert-base
# huggingface/facebook/bart-large
response = embedding(
model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
input=["good morning from litellm", "this is another item"],
)
print(f"response:", response)
except Exception as e:
# Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
pass
# test_hf_embedding()
# test async embeddings
def test_aembedding():
try:
import asyncio
async def embedding_call():
try:
response = await litellm.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"]
input=["good morning from litellm", "this is another item"],
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call())
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_aembedding()
def test_aembedding_azure():
try:
import asyncio
async def embedding_call():
try:
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm", "this is another item"]
input=["good morning from litellm", "this is another item"],
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call())
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_aembedding_azure()
def test_sagemaker_embeddings():
try:
response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
response = litellm.embedding(
model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
input=["good morning from litellm", "this is another item"],
)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_sagemaker_embeddings()
# def local_proxy_embeddings():
# litellm.set_verbose=True
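As a compact reminder of what this embedding file keeps asserting, the expected response shape is shown below in one place; the model name and checks are lifted from the tests above, and running it requires a real OPENAI_API_KEY.

import litellm

response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["good morning from litellm", "this is another item"],
)

# Shape asserted throughout this file (the tests also discard "_response_ms"
# before comparing key sets against the raw openai client response).
assert {"usage", "model", "object", "data"} <= set(dict(response).keys())
assert len(response["data"]) == 2  # one embedding per input string
assert isinstance(response["data"][0]["embedding"], list)
assert all(isinstance(x, float) for x in response["data"][0]["embedding"])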
View file
@ -11,17 +11,18 @@ import litellm
from litellm import (
embedding,
completion,
# AuthenticationError,
# AuthenticationError,
ContextWindowExceededError,
# RateLimitError,
# ServiceUnavailableError,
# OpenAIError,
# RateLimitError,
# ServiceUnavailableError,
# OpenAIError,
)
from concurrent.futures import ThreadPoolExecutor
import pytest
litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
litellm.num_retries=0
litellm.num_retries = 0
# litellm.failure_callback = ["sentry"]
#### What this tests ####
@ -36,6 +37,7 @@ litellm.num_retries=0
models = ["command-nightly"]
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
@ -56,13 +58,23 @@ def test_context_window(model):
print(f"{e}")
pytest.fail(f"An error occcurred - {e}")
@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"}
ctx_window_fallback_dict = {
"command-nightly": "claude-2",
"gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
"azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
}
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
completion(
model=model,
messages=messages,
context_window_fallback_dict=ctx_window_fallback_dict,
)
# for model in litellm.models_by_provider["bedrock"]:
# test_context_window(model=model)
@ -98,7 +110,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["AI21_API_KEY"] = "bad-key"
elif "togethercomputer" in model:
temporary_key = os.environ["TOGETHERAI_API_KEY"]
os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
os.environ[
"TOGETHERAI_API_KEY"
] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
elif model in litellm.openrouter_models:
temporary_key = os.environ["OPENROUTER_API_KEY"]
os.environ["OPENROUTER_API_KEY"] = "bad-key"
@ -115,9 +129,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
temporary_key = os.environ["REPLICATE_API_KEY"]
os.environ["REPLICATE_API_KEY"] = "bad-key"
print(f"model: {model}")
response = completion(
model=model, messages=messages
)
response = completion(model=model, messages=messages)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {str(e)}")
@ -148,7 +160,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["REPLICATE_API_KEY"] = temporary_key
elif "j2" in model:
os.environ["AI21_API_KEY"] = temporary_key
elif ("togethercomputer" in model):
elif "togethercomputer" in model:
os.environ["TOGETHERAI_API_KEY"] = temporary_key
elif model in litellm.aleph_alpha_models:
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
@ -160,10 +172,12 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return
# for model in litellm.models_by_provider["bedrock"]:
# invalid_auth(model=model)
# invalid_auth(model="command-nightly")
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):
@ -173,23 +187,18 @@ def test_invalid_request_error(model):
completion(model=model, messages=messages, max_tokens="hello world")
def test_completion_azure_exception():
try:
import openai
print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning"
response = completion(
model="azure/chatgpt-v-2",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
os.environ["AZURE_API_KEY"] = old_azure_key
print(f"response: {response}")
@ -199,25 +208,24 @@ def test_completion_azure_exception():
print("good job got the correct error for azure when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_exception()
async def asynctest_completion_azure_exception():
try:
import openai
import litellm
print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning"
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception():
print("Got wrong exception")
print("exception", e)
pytest.fail(f"Error occurred: {e}")
# import asyncio
# asyncio.run(
# asynctest_completion_azure_exception()
@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
async def test():
response = await litellm.acompletion(
model="openai/gpt-6",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model():
assert isinstance(e, openai.BadRequestError)
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_openai_exception_bad_model()
# asynctest_completion_openai_exception_bad_model()
def asynctest_completion_azure_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
async def test():
response = await litellm.acompletion(
model="azure/gpt-12",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model():
print("Raised wrong type of exception", type(e))
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_azure_exception_bad_model()
def test_completion_openai_exception():
# test if openai:gpt raises openai.AuthenticationError
try:
import openai
print("openai gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "good morning"
response = completion(
model="gpt-4",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -317,25 +321,24 @@ def test_completion_openai_exception():
print("OpenAI: good job got the correct error for openai when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception()
def test_completion_mistral_exception():
# test if mistral/mistral-tiny raises openai.AuthenticationError
try:
import openai
print("Testing mistral ai exception mapping")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["MISTRAL_API_KEY"]
os.environ["MISTRAL_API_KEY"] = "good morning"
response = completion(
model="mistral/mistral-tiny",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -344,11 +347,11 @@ def test_completion_mistral_exception():
print("good job got the correct error for openai when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception()
# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):
View file
@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
import pytest
litellm.num_retries = 0
litellm.cache = None
# litellm.set_verbose=True
@ -20,23 +21,32 @@ import json
# litellm.success_callback = ["langfuse"]
def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"})
return json.dumps(
{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
)
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
else:
return json.dumps({"location": location, "temperature": "unknown"})
# Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API
def test_parallel_function_call():
try:
# Step 1: send the conversation and available functions to the model
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
}
]
tools = [
{
"type": "function",
@ -50,7 +60,10 @@ def test_parallel_function_call():
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
@ -69,7 +82,9 @@ def test_parallel_function_call():
print("length of tool calls", len(tool_calls))
print("Expecting there to be 3 tool calls")
assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and Paris
assert (
len(tool_calls) > 1
) # this has to call the function for SF, Tokyo and Paris
# Step 2: check if the model wanted to call a function
if tool_calls:
@ -78,7 +93,9 @@ def test_parallel_function_call():
available_functions = {
"get_current_weather": get_current_weather,
} # only one function in this example, but you can have multiple
messages.append(response_message) # extend conversation with assistant's reply
messages.append(
response_message
) # extend conversation with assistant's reply
print("Response message\n", response_message)
# Step 4: send the info for each function call and function response to the model
for tool_call in tool_calls:
@ -99,25 +116,26 @@ def test_parallel_function_call():
) # extend conversation with function response
print(f"messages: {messages}")
second_response = litellm.completion(
model="gpt-3.5-turbo-1106",
messages=messages,
temperature=0.2,
seed=22
model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22
) # get a new response from the model where it can see the function response
print("second response\n", second_response)
return second_response
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_parallel_function_call()
def test_parallel_function_call_stream():
try:
# Step 1: send the conversation and available functions to the model
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
}
]
tools = [
{
"type": "function",
@ -131,7 +149,10 @@ def test_parallel_function_call_stream():
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
@ -144,7 +165,7 @@ def test_parallel_function_call_stream():
tools=tools,
stream=True,
tool_choice="auto", # auto is default, but we'll be explicit
complete_response = True
complete_response=True,
)
print("Response\n", response)
# for chunk in response:
@ -154,7 +175,9 @@ def test_parallel_function_call_stream():
print("length of tool calls", len(tool_calls))
print("Expecting there to be 3 tool calls")
assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and Paris
assert (
len(tool_calls) > 1
) # this has to call the function for SF, Tokyo and Paris
# Step 2: check if the model wanted to call a function
if tool_calls:
@ -163,7 +186,9 @@ def test_parallel_function_call_stream():
available_functions = {
"get_current_weather": get_current_weather,
} # only one function in this example, but you can have multiple
messages.append(response_message) # extend conversation with assistant's reply
messages.append(
response_message
) # extend conversation with assistant's reply
print("Response message\n", response_message)
# Step 4: send the info for each function call and function response to the model
for tool_call in tool_calls:
@ -184,14 +209,12 @@ def test_parallel_function_call_stream():
) # extend conversation with function response
print(f"messages: {messages}")
second_response = litellm.completion(
model="gpt-3.5-turbo-1106",
messages=messages,
temperature=0.2,
seed=22
model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22
) # get a new response from the model where it can see the function response
print("second response\n", second_response)
return second_response
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_parallel_function_call_stream()
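Both parallel-tool-call tests share the same second half: dispatch each tool_call to a local function, append the result as a "tool" message, and call the model again so it can read the outputs. A compressed sketch of that loop is below; the fields on the appended message follow the standard OpenAI tool-calling shape, which these hunks do not show explicitly, so treat them as an assumption.

import json

import litellm


def get_current_weather(location, unit="fahrenheit"):
    # stand-in for the dummy weather function defined at the top of the test file
    return json.dumps({"location": location, "temperature": "unknown", "unit": unit})


available_functions = {"get_current_weather": get_current_weather}


def answer_tool_calls(response_message, messages):
    """Run each requested tool and send the results back to the model."""
    messages.append(response_message)  # the assistant turn that asked for tools
    for tool_call in response_message.tool_calls:
        fn = available_functions[tool_call.function.name]
        fn_args = json.loads(tool_call.function.arguments)
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": fn(**fn_args),
            }
        )
    # second round-trip so the model can see the tool outputs
    return litellm.completion(model="gpt-3.5-turbo-1106", messages=messages)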
View file
@ -11,9 +11,11 @@ sys.path.insert(
import pytest
import litellm
def test_get_llm_provider():
_, response, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
assert response == "bedrock"
test_get_llm_provider()
View file
@ -9,41 +9,60 @@ import litellm
from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
import pytest
def test_get_gpt3_tokens():
max_tokens = get_max_tokens("gpt-3.5-turbo")
print(max_tokens)
assert max_tokens==4097
assert max_tokens == 4097
# print(results)
test_get_gpt3_tokens()
def test_get_palm_tokens():
# # 🦄🦄🦄🦄🦄🦄🦄🦄
max_tokens = get_max_tokens("palm/chat-bison")
assert max_tokens == 4096
print(max_tokens)
test_get_palm_tokens()
def test_zephyr_hf_tokens():
max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
print(max_tokens)
assert max_tokens == 32768
test_zephyr_hf_tokens()
def test_cost_ft_gpt_35():
try:
# this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id
# it needs to lookup ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost
from litellm import ModelResponse, Choices, Message
from litellm.utils import Usage
resp = ModelResponse(
id='chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac',
choices=[Choices(finish_reason=None, index=0,
message=Message(content=' Sure! Here is a short poem about the sky:\n\nA canvas of blue, a', role='assistant'))],
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
created=1700775391,
model='ft:gpt-3.5-turbo:my-org:custom_suffix:id',
object='chat.completion', system_fingerprint=None,
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38)
model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
object="chat.completion",
system_fingerprint=None,
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
)
cost = litellm.completion_cost(completion_response=resp)
@ -51,35 +70,58 @@ def test_cost_ft_gpt_35():
input_cost = model_cost["ft:gpt-3.5-turbo"]["input_cost_per_token"]
output_cost = model_cost["ft:gpt-3.5-turbo"]["output_cost_per_token"]
print(input_cost, output_cost)
expected_cost = (input_cost*resp.usage.prompt_tokens) + (output_cost*resp.usage.completion_tokens)
expected_cost = (input_cost * resp.usage.prompt_tokens) + (
output_cost * resp.usage.completion_tokens
)
print("\n Excpected cost", expected_cost)
assert cost == expected_cost
except Exception as e:
pytest.fail(f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}")
pytest.fail(
f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}"
)
test_cost_ft_gpt_35()
def test_cost_azure_gpt_35():
try:
# this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo
# for this test we check if passing `model` to completion_cost overrides the completion cost
from litellm import ModelResponse, Choices, Message
from litellm.utils import Usage
resp = ModelResponse(
id='chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac',
choices=[Choices(finish_reason=None, index=0,
message=Message(content=' Sure! Here is a short poem about the sky:\n\nA canvas of blue, a', role='assistant'))],
model='azure/gpt-35-turbo', # azure always has model written like this
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38)
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
model="azure/gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
)
cost = litellm.completion_cost(completion_response=resp, model="azure/gpt-3.5-turbo")
cost = litellm.completion_cost(
completion_response=resp, model="azure/gpt-3.5-turbo"
)
print("\n Calculated Cost for azure/gpt-3.5-turbo", cost)
input_cost = model_cost["azure/gpt-3.5-turbo"]["input_cost_per_token"]
output_cost = model_cost["azure/gpt-3.5-turbo"]["output_cost_per_token"]
expected_cost = (input_cost*resp.usage.prompt_tokens) + (output_cost*resp.usage.completion_tokens)
expected_cost = (input_cost * resp.usage.prompt_tokens) + (
output_cost * resp.usage.completion_tokens
)
print("\n Excpected cost", expected_cost)
assert cost == expected_cost
except Exception as e:
pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}")
test_cost_azure_gpt_35()
pytest.fail(
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
)
test_cost_azure_gpt_35()
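A small sketch of the override behaviour checked here: the response object carries model="azure/gpt-35-turbo", but the explicit model argument to completion_cost decides which pricing entry is used. This assumes ModelResponse fills in the optional fields (id, created, object) when they are omitted.

from litellm import ModelResponse, Choices, Message, completion_cost
from litellm.utils import Usage

resp = ModelResponse(
    choices=[
        Choices(
            finish_reason=None,
            index=0,
            message=Message(content="A short poem", role="assistant"),
        )
    ],
    model="azure/gpt-35-turbo",  # deployment-style name on the response itself
    usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
)

# Pricing is looked up under the explicitly passed model, not resp.model.
print(completion_cost(completion_response=resp, model="azure/gpt-3.5-turbo"))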

View file

@ -11,19 +11,38 @@ sys.path.insert(
import pytest
from litellm.llms.prompt_templates.factory import prompt_factory
def test_prompt_formatting():
try:
prompt = prompt_factory(model="mistralai/Mistral-7B-Instruct-v0.1", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}])
assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
prompt = prompt_factory(
model="mistralai/Mistral-7B-Instruct-v0.1",
messages=[
{"role": "system", "content": "Be a good bot"},
{"role": "user", "content": "Hello world"},
],
)
assert (
prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
)
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
def test_prompt_formatting_custom_model():
try:
prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface")
prompt = prompt_factory(
model="ehartford/dolphin-2.5-mixtral-8x7b",
messages=[
{"role": "system", "content": "Be a good bot"},
{"role": "user", "content": "Hello world"},
],
custom_llm_provider="huggingface",
)
print(f"prompt: {prompt}")
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict):
# return
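For reference, a short sketch of the prompt templating these tests exercise; the Mistral output string is taken from the assertion above, while the Hugging Face fallback output is only printed, not asserted.

from litellm.llms.prompt_templates.factory import prompt_factory

messages = [
    {"role": "system", "content": "Be a good bot"},
    {"role": "user", "content": "Hello world"},
]

# Known model family: the Mistral [INST] chat template is applied.
print(prompt_factory(model="mistralai/Mistral-7B-Instruct-v0.1", messages=messages))
# -> "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"

# Arbitrary Hugging Face model: pass custom_llm_provider to pick the template path.
print(
    prompt_factory(
        model="ehartford/dolphin-2.5-mixtral-8x7b",
        messages=messages,
        custom_llm_provider="huggingface",
    )
)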

View file

@ -5,45 +5,71 @@ import sys, os
import traceback
from dotenv import load_dotenv
import logging
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import os
import asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
def test_image_generation_openai():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
response = litellm.image_generation(
prompt="A cute baby sea otter", model="dall-e-3"
)
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_openai()
def test_image_generation_azure():
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview")
response = litellm.image_generation(
prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview"
)
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure()
def test_image_generation_azure_dall_e_3():
litellm.set_verbose = True
response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY"))
response = litellm.image_generation(
prompt="A cute baby sea otter",
model="azure/dall-e-3-test",
api_version="2023-12-01-preview",
api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
)
print(f"response: {response}")
assert len(response.data) > 0
# test_image_generation_azure_dall_e_3()
@pytest.mark.asyncio
async def test_async_image_generation_openai():
response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
response = litellm.image_generation(
prompt="A cute baby sea otter", model="dall-e-3"
)
print(f"response: {response}")
assert len(response.data) > 0
# asyncio.run(test_async_image_generation_openai())
@pytest.mark.asyncio
async def test_async_image_generation_azure():
response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test")
response = await litellm.aimage_generation(
prompt="A cute baby sea otter", model="azure/dall-e-3-test"
)
print(f"response: {response}")

View file

@ -104,4 +104,3 @@
# # pytest.fail(f"Error occurred: {e}")
# # test_openai_with_params()

Some files were not shown because too many files have changed in this diff.