Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Commit 4905929de3 (parent b87d630b0a)
refactor: add black formatting

156 changed files with 19723 additions and 10869 deletions
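Most of the changes below are mechanical Black rewrites: single-quoted strings become double-quoted, and long calls, dicts, and lists are wrapped one element per line. As a rough illustration only (not part of the commit, and assuming the black package is installed), the same transformation can be reproduced with Black's format_str API:

# Sketch only: reproduce the kind of rewrite seen throughout this diff.
import black

source = "models = ['gpt-3.5-turbo', 'claude-2']\nos.environ['OPENAI_API_KEY'] = ''\n"
print(black.format_str(source, mode=black.FileMode()))
# Output:
# models = ["gpt-3.5-turbo", "claude-2"]
# os.environ["OPENAI_API_KEY"] = ""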
@@ -1,9 +1,13 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
- repo: local
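The hooks above can also be exercised locally before committing; a hypothetical helper (not in the repo), assuming pre-commit, black, flake8, and flake8-print are installed:

# Hypothetical helper: run the same checks the pre-commit config wires up.
import subprocess

subprocess.run(["black", "."], check=True)  # format in place
subprocess.run(["flake8", "litellm/"], check=False)  # lint; flake8-print flags stray print statements
subprocess.run(["pre-commit", "run", "--all-files"], check=False)  # or run every configured hook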
@ -9,33 +9,37 @@ import os
|
|||
|
||||
# Define the list of models to benchmark
|
||||
# select any LLM listed here: https://docs.litellm.ai/docs/providers
|
||||
models = ['gpt-3.5-turbo', 'claude-2']
|
||||
models = ["gpt-3.5-turbo", "claude-2"]
|
||||
|
||||
# Enter LLM API keys
|
||||
# https://docs.litellm.ai/docs/providers
|
||||
os.environ['OPENAI_API_KEY'] = ""
|
||||
os.environ['ANTHROPIC_API_KEY'] = ""
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["ANTHROPIC_API_KEY"] = ""
|
||||
|
||||
# List of questions to benchmark (replace with your questions)
|
||||
questions = [
|
||||
"When will BerriAI IPO?",
|
||||
"When will LiteLLM hit $100M ARR?"
|
||||
]
|
||||
questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"]
|
||||
|
||||
# Enter your system prompt here
|
||||
# Enter your system prompt here
|
||||
system_prompt = """
|
||||
You are LiteLLMs helpful assistant
|
||||
"""
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--system-prompt', default="You are a helpful assistant that can answer questions.", help="System prompt for the conversation.")
|
||||
@click.option(
|
||||
"--system-prompt",
|
||||
default="You are a helpful assistant that can answer questions.",
|
||||
help="System prompt for the conversation.",
|
||||
)
|
||||
def main(system_prompt):
|
||||
for question in questions:
|
||||
data = [] # Data for the current question
|
||||
|
||||
with tqdm(total=len(models)) as pbar:
|
||||
for model in models:
|
||||
colored_description = colored(f"Running question: {question} for model: {model}", 'green')
|
||||
colored_description = colored(
|
||||
f"Running question: {question} for model: {model}", "green"
|
||||
)
|
||||
pbar.set_description(colored_description)
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -44,35 +48,43 @@ def main(system_prompt):
|
|||
max_tokens=500,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": question}
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
total_time = end - start_time
|
||||
cost = completion_cost(completion_response=response)
|
||||
raw_response = response['choices'][0]['message']['content']
|
||||
raw_response = response["choices"][0]["message"]["content"]
|
||||
|
||||
data.append({
|
||||
'Model': colored(model, 'light_blue'),
|
||||
'Response': raw_response, # Colorize the response
|
||||
'ResponseTime': colored(f"{total_time:.2f} seconds", "red"),
|
||||
'Cost': colored(f"${cost:.6f}", 'green'), # Colorize the cost
|
||||
})
|
||||
data.append(
|
||||
{
|
||||
"Model": colored(model, "light_blue"),
|
||||
"Response": raw_response, # Colorize the response
|
||||
"ResponseTime": colored(f"{total_time:.2f} seconds", "red"),
|
||||
"Cost": colored(f"${cost:.6f}", "green"), # Colorize the cost
|
||||
}
|
||||
)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Separate headers from the data
|
||||
headers = ['Model', 'Response', 'Response Time (seconds)', 'Cost ($)']
|
||||
headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"]
|
||||
colwidths = [15, 80, 15, 10]
|
||||
|
||||
# Create a nicely formatted table for the current question
|
||||
table = tabulate([list(d.values()) for d in data], headers, tablefmt="grid", maxcolwidths=colwidths)
|
||||
|
||||
table = tabulate(
|
||||
[list(d.values()) for d in data],
|
||||
headers,
|
||||
tablefmt="grid",
|
||||
maxcolwidths=colwidths,
|
||||
)
|
||||
|
||||
# Print the table for the current question
|
||||
colored_question = colored(question, 'green')
|
||||
colored_question = colored(question, "green")
|
||||
click.echo(f"\nBenchmark Results for '{colored_question}':")
|
||||
click.echo(table) # Display the formatted table
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@@ -1,25 +1,22 @@
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()

import litellm
from litellm import embedding, completion, completion_cost

from autoevals.llm import *

###################
import litellm

# litellm completion call
question = "which country has the highest population"
response = litellm.completion(
model = "gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": question
}
],
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": question}],
)
print(response)
# use the auto eval Factuality() evaluator

@@ -27,9 +24,11 @@ print(response)
print("calling evaluator")
evaluator = Factuality()
result = evaluator(
output=response.choices[0]["message"]["content"], # response from litellm.completion()
expected="India", # expected output
input=question # question passed to litellm.completion
output=response.choices[0]["message"][
"content"
], # response from litellm.completion()
expected="India", # expected output
input=question, # question passed to litellm.completion
)

print(result)
print(result)
|
|
@ -4,9 +4,10 @@ from flask_cors import CORS
|
|||
import traceback
|
||||
import litellm
|
||||
from util import handle_error
|
||||
from litellm import completion
|
||||
import os, dotenv, time
|
||||
from litellm import completion
|
||||
import os, dotenv, time
|
||||
import json
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# TODO: set your keys in .env or here:
|
||||
|
@ -19,57 +20,72 @@ verbose = True
|
|||
|
||||
# litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/
|
||||
######### PROMPT LOGGING ##########
|
||||
os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/
|
||||
os.environ[
|
||||
"PROMPTLAYER_API_KEY"
|
||||
] = "" # set your promptlayer key here - https://promptlayer.com/
|
||||
|
||||
# set callbacks
|
||||
litellm.success_callback = ["promptlayer"]
|
||||
############ HELPER FUNCTIONS ###################################
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if verbose:
|
||||
print(print_statement)
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
@app.route('/')
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return 'received!', 200
|
||||
return "received!", 200
|
||||
|
||||
|
||||
def data_generator(response):
|
||||
for chunk in response:
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
@app.route('/chat/completions', methods=["POST"])
|
||||
|
||||
@app.route("/chat/completions", methods=["POST"])
|
||||
def api_completion():
|
||||
data = request.json
|
||||
start_time = time.time()
|
||||
if data.get('stream') == "True":
|
||||
data['stream'] = True # convert to boolean
|
||||
start_time = time.time()
|
||||
if data.get("stream") == "True":
|
||||
data["stream"] = True # convert to boolean
|
||||
try:
|
||||
if "prompt" not in data:
|
||||
raise ValueError("data needs to have prompt")
|
||||
data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
|
||||
data[
|
||||
"model"
|
||||
] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
|
||||
# COMPLETION CALL
|
||||
system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
|
||||
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}]
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": data.pop("prompt")},
|
||||
]
|
||||
data["messages"] = messages
|
||||
print(f"data: {data}")
|
||||
response = completion(**data)
|
||||
## LOG SUCCESS
|
||||
end_time = time.time()
|
||||
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||
return Response(data_generator(response), mimetype='text/event-stream')
|
||||
end_time = time.time()
|
||||
if (
|
||||
"stream" in data and data["stream"] == True
|
||||
): # use generate_responses to stream responses
|
||||
return Response(data_generator(response), mimetype="text/event-stream")
|
||||
except Exception as e:
|
||||
# call handle_error function
|
||||
print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
|
||||
## LOG FAILURE
|
||||
end_time = time.time()
|
||||
end_time = time.time()
|
||||
traceback_exception = traceback.format_exc()
|
||||
return handle_error(data=data)
|
||||
return response
|
||||
|
||||
@app.route('/get_models', methods=["POST"])
|
||||
|
||||
@app.route("/get_models", methods=["POST"])
|
||||
def get_models():
|
||||
try:
|
||||
return litellm.model_list
|
||||
|
@ -78,7 +94,8 @@ def get_models():
|
|||
response = {"error": str(e)}
|
||||
return response, 200
|
||||
|
||||
if __name__ == "__main__":
|
||||
from waitress import serve
|
||||
serve(app, host="0.0.0.0", port=4000, threads=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from waitress import serve
|
||||
|
||||
serve(app, host="0.0.0.0", port=4000, threads=500)
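A hypothetical client call against the proxy defined above (not in the diff; assumes the app is running locally via waitress on port 4000, and relies on the /chat/completions route requiring a "prompt" field and defaulting to the Together AI CodeLlama model):

# Sketch only: exercise the /chat/completions route shown above.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    json={"prompt": "Write a bubble sort in Python"},
)
print(resp.status_code)
print(resp.text[:500])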
|
||||
|
|
|
@ -3,27 +3,28 @@ from urllib.parse import urlparse, parse_qs
|
|||
|
||||
|
||||
def get_next_url(response):
|
||||
"""
|
||||
"""
|
||||
Function to get 'next' url from Link header
|
||||
:param response: response from requests
|
||||
:return: next url or None
|
||||
"""
|
||||
if 'link' not in response.headers:
|
||||
return None
|
||||
headers = response.headers
|
||||
if "link" not in response.headers:
|
||||
return None
|
||||
headers = response.headers
|
||||
|
||||
next_url = headers['Link']
|
||||
print(next_url)
|
||||
start_index = next_url.find("<")
|
||||
end_index = next_url.find(">")
|
||||
next_url = headers["Link"]
|
||||
print(next_url)
|
||||
start_index = next_url.find("<")
|
||||
end_index = next_url.find(">")
|
||||
|
||||
return next_url[1:end_index]
|
||||
|
||||
return next_url[1:end_index]
|
||||
|
||||
def get_models(url):
|
||||
"""
|
||||
Function to retrieve all models from paginated endpoint
|
||||
:param url: base url to make GET request
|
||||
:return: list of all models
|
||||
Function to retrieve all models from paginated endpoint
|
||||
:param url: base url to make GET request
|
||||
:return: list of all models
|
||||
"""
|
||||
models = []
|
||||
while url:
|
||||
|
@ -36,19 +37,21 @@ def get_models(url):
|
|||
models.extend(payload)
|
||||
return models
|
||||
|
||||
|
||||
def get_cleaned_models(models):
|
||||
"""
|
||||
Function to clean retrieved models
|
||||
:param models: list of retrieved models
|
||||
:return: list of cleaned models
|
||||
Function to clean retrieved models
|
||||
:param models: list of retrieved models
|
||||
:return: list of cleaned models
|
||||
"""
|
||||
cleaned_models = []
|
||||
for model in models:
|
||||
cleaned_models.append(model["id"])
|
||||
return cleaned_models
|
||||
|
||||
|
||||
# Get text-generation models
|
||||
url = 'https://huggingface.co/api/models?filter=text-generation-inference'
|
||||
url = "https://huggingface.co/api/models?filter=text-generation-inference"
|
||||
text_generation_models = get_models(url)
|
||||
cleaned_text_generation_models = get_cleaned_models(text_generation_models)
|
||||
|
||||
|
@ -56,7 +59,7 @@ print(cleaned_text_generation_models)
|
|||
|
||||
|
||||
# Get conversational models
|
||||
url = 'https://huggingface.co/api/models?filter=conversational'
|
||||
url = "https://huggingface.co/api/models?filter=conversational"
|
||||
conversational_models = get_models(url)
|
||||
cleaned_conversational_models = get_cleaned_models(conversational_models)
|
||||
|
||||
|
@ -65,19 +68,23 @@ print(cleaned_conversational_models)
|
|||
|
||||
def write_to_txt(cleaned_models, filename):
|
||||
"""
|
||||
Function to write the contents of a list to a text file
|
||||
:param cleaned_models: list of cleaned models
|
||||
:param filename: name of the text file
|
||||
Function to write the contents of a list to a text file
|
||||
:param cleaned_models: list of cleaned models
|
||||
:param filename: name of the text file
|
||||
"""
|
||||
with open(filename, 'w') as f:
|
||||
with open(filename, "w") as f:
|
||||
for item in cleaned_models:
|
||||
f.write("%s\n" % item)
|
||||
|
||||
|
||||
# Write contents of cleaned_text_generation_models to text_generation_models.txt
|
||||
write_to_txt(cleaned_text_generation_models, 'huggingface_llms_metadata/hf_text_generation_models.txt')
|
||||
write_to_txt(
|
||||
cleaned_text_generation_models,
|
||||
"huggingface_llms_metadata/hf_text_generation_models.txt",
|
||||
)
|
||||
|
||||
# Write contents of cleaned_conversational_models to conversational_models.txt
|
||||
write_to_txt(cleaned_conversational_models, 'huggingface_llms_metadata/hf_conversational_models.txt')
|
||||
|
||||
|
||||
write_to_txt(
|
||||
cleaned_conversational_models,
|
||||
"huggingface_llms_metadata/hf_conversational_models.txt",
|
||||
)
|
||||
|
|
|
@@ -1,4 +1,3 @@

import openai

api_base = f"http://0.0.0.0:8000"

@@ -8,29 +7,29 @@ openai.api_key = "temp-key"
print(openai.api_base)


print(f'LiteLLM: response from proxy with streaming')
print(f"LiteLLM: response from proxy with streaming")
response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
model="ollama/llama2",
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
],
stream=True
stream=True,
)

for chunk in response:
print(f'LiteLLM: streaming response from proxy {chunk}')
print(f"LiteLLM: streaming response from proxy {chunk}")

response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
model="ollama/llama2",
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
]
],
)

print(f'LiteLLM: response from proxy {response}')
print(f"LiteLLM: response from proxy {response}")
|
|
@ -12,42 +12,51 @@ import pytest
|
|||
|
||||
from litellm import Router
|
||||
import litellm
|
||||
litellm.set_verbose=False
|
||||
|
||||
litellm.set_verbose = False
|
||||
os.environ.pop("AZURE_AD_TOKEN")
|
||||
|
||||
model_list = [{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
}
|
||||
}]
|
||||
model_list = [
|
||||
{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list)
|
||||
|
||||
|
||||
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
|
||||
file_paths = [
|
||||
"test_questions/question1.txt",
|
||||
"test_questions/question2.txt",
|
||||
"test_questions/question3.txt",
|
||||
]
|
||||
questions = []
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
print(file_path)
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
content = file.read()
|
||||
questions.append(content)
|
||||
except FileNotFoundError as e:
|
||||
|
@ -59,10 +68,9 @@ for file_path in file_paths:
|
|||
# print(q)
|
||||
|
||||
|
||||
|
||||
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
@ -74,10 +82,18 @@ def make_openai_completion(question):
|
|||
try:
|
||||
start_time = time.time()
|
||||
import openai
|
||||
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
|
||||
|
||||
client = openai.OpenAI(
|
||||
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
|
||||
) # base_url="http://0.0.0.0:8000",
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a helpful assistant. Answer this question{question}",
|
||||
}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
end_time = time.time()
|
||||
|
@ -92,11 +108,10 @@ def make_openai_completion(question):
|
|||
except Exception as e:
|
||||
# Log exceptions for failed calls
|
||||
with open("error_log.txt", "a") as error_log_file:
|
||||
error_log_file.write(
|
||||
f"Question: {question[:100]}\nException: {str(e)}\n\n"
|
||||
)
|
||||
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
|
||||
return None
|
||||
|
||||
|
||||
# Number of concurrent calls (you can adjust this)
|
||||
concurrent_calls = 100
|
||||
|
||||
|
@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file:
|
|||
|
||||
with open("error_log.txt", "r") as error_log_file:
|
||||
print("\nError Log:\n", error_log_file.read())
|
||||
|
||||
|
|
|
@ -12,42 +12,51 @@ import pytest
|
|||
|
||||
from litellm import Router
|
||||
import litellm
|
||||
litellm.set_verbose=False
|
||||
|
||||
litellm.set_verbose = False
|
||||
# os.environ.pop("AZURE_AD_TOKEN")
|
||||
|
||||
model_list = [{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
}
|
||||
}]
|
||||
model_list = [
|
||||
{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list)
|
||||
|
||||
|
||||
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
|
||||
file_paths = [
|
||||
"test_questions/question1.txt",
|
||||
"test_questions/question2.txt",
|
||||
"test_questions/question3.txt",
|
||||
]
|
||||
questions = []
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
print(file_path)
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
content = file.read()
|
||||
questions.append(content)
|
||||
except FileNotFoundError as e:
|
||||
|
@ -59,10 +68,9 @@ for file_path in file_paths:
|
|||
# print(q)
|
||||
|
||||
|
||||
|
||||
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
@ -76,9 +84,12 @@ def make_openai_completion(question):
|
|||
import requests
|
||||
|
||||
data = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'messages': [
|
||||
{'role': 'system', 'content': f'You are a helpful assistant. Answer this question{question}'},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a helpful assistant. Answer this question{question}",
|
||||
},
|
||||
],
|
||||
}
|
||||
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
|
||||
|
@ -89,8 +100,8 @@ def make_openai_completion(question):
|
|||
log_file.write(
|
||||
f"Question: {question[:100]}\nResponse ID: {response.get('id', 'N/A')} Url: {response.get('url', 'N/A')}\nTime: {end_time - start_time:.2f} seconds\n\n"
|
||||
)
|
||||
|
||||
# polling the url
|
||||
|
||||
# polling the url
|
||||
while True:
|
||||
try:
|
||||
url = response["url"]
|
||||
|
@ -107,7 +118,9 @@ def make_openai_completion(question):
|
|||
)
|
||||
|
||||
break
|
||||
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
|
||||
print(
|
||||
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
|
||||
)
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
print("got exception in polling", e)
|
||||
|
@ -117,11 +130,10 @@ def make_openai_completion(question):
|
|||
except Exception as e:
|
||||
# Log exceptions for failed calls
|
||||
with open("error_log.txt", "a") as error_log_file:
|
||||
error_log_file.write(
|
||||
f"Question: {question[:100]}\nException: {str(e)}\n\n"
|
||||
)
|
||||
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
|
||||
return None
|
||||
|
||||
|
||||
# Number of concurrent calls (you can adjust this)
|
||||
concurrent_calls = 10
|
||||
|
||||
|
@ -142,7 +154,7 @@ successful_calls = 0
|
|||
failed_calls = 0
|
||||
|
||||
for future in futures:
|
||||
if future.done():
|
||||
if future.done():
|
||||
if future.result() is not None:
|
||||
successful_calls += 1
|
||||
else:
|
||||
|
@ -152,4 +164,3 @@ print(f"Load test Summary:")
|
|||
print(f"Total Requests: {concurrent_calls}")
|
||||
print(f"Successful Calls: {successful_calls}")
|
||||
print(f"Failed Calls: {failed_calls}")
|
||||
|
||||
|
|
|
@ -12,42 +12,51 @@ import pytest
|
|||
|
||||
from litellm import Router
|
||||
import litellm
|
||||
litellm.set_verbose=False
|
||||
|
||||
litellm.set_verbose = False
|
||||
os.environ.pop("AZURE_AD_TOKEN")
|
||||
|
||||
model_list = [{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
}
|
||||
}, {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
}
|
||||
}]
|
||||
model_list = [
|
||||
{ # list of model deployments
|
||||
"model_name": "gpt-3.5-turbo", # model alias
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2", # actual model name
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list)
|
||||
|
||||
|
||||
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
|
||||
file_paths = [
|
||||
"test_questions/question1.txt",
|
||||
"test_questions/question2.txt",
|
||||
"test_questions/question3.txt",
|
||||
]
|
||||
questions = []
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
print(file_path)
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
content = file.read()
|
||||
questions.append(content)
|
||||
except FileNotFoundError as e:
|
||||
|
@ -59,10 +68,9 @@ for file_path in file_paths:
|
|||
# print(q)
|
||||
|
||||
|
||||
|
||||
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
|
||||
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
@ -75,7 +83,12 @@ def make_openai_completion(question):
|
|||
start_time = time.time()
|
||||
response = router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a helpful assistant. Answer this question{question}",
|
||||
}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
end_time = time.time()
|
||||
|
@ -90,11 +103,10 @@ def make_openai_completion(question):
|
|||
except Exception as e:
|
||||
# Log exceptions for failed calls
|
||||
with open("error_log.txt", "a") as error_log_file:
|
||||
error_log_file.write(
|
||||
f"Question: {question[:100]}\nException: {str(e)}\n\n"
|
||||
)
|
||||
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
|
||||
return None
|
||||
|
||||
|
||||
# Number of concurrent calls (you can adjust this)
|
||||
concurrent_calls = 150
|
||||
|
||||
|
@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file:
|
|||
|
||||
with open("error_log.txt", "r") as error_log_file:
|
||||
print("\nError Log:\n", error_log_file.read())
|
||||
|
||||
|
|
|
@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = []
|
|||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_input_callback: List[
|
||||
Callable
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[
|
||||
Union[str, Callable]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[
|
||||
Callable
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
email: Optional[
|
||||
|
@ -42,20 +48,88 @@ aleph_alpha_key: Optional[str] = None
|
|||
nlp_cloud_key: Optional[str] = None
|
||||
use_client: bool = False
|
||||
logging: bool = True
|
||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
cache: Optional[
|
||||
Cache
|
||||
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
|
||||
_current_cost = 0 # private variable, used if max budget is set
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
_openai_completion_params = [
|
||||
"functions",
|
||||
"function_call",
|
||||
"temperature",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"request_timeout",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"api_key",
|
||||
"deployment_id",
|
||||
"organization",
|
||||
"base_url",
|
||||
"default_headers",
|
||||
"timeout",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
]
|
||||
_litellm_completion_params = [
|
||||
"metadata",
|
||||
"acompletion",
|
||||
"caching",
|
||||
"mock_response",
|
||||
"api_key",
|
||||
"api_version",
|
||||
"api_base",
|
||||
"force_timeout",
|
||||
"logger_fn",
|
||||
"verbose",
|
||||
"custom_llm_provider",
|
||||
"litellm_logging_obj",
|
||||
"litellm_call_id",
|
||||
"use_client",
|
||||
"id",
|
||||
"fallbacks",
|
||||
"azure",
|
||||
"headers",
|
||||
"model_list",
|
||||
"num_retries",
|
||||
"context_window_fallback_dict",
|
||||
"roles",
|
||||
"final_prompt_value",
|
||||
"bos_token",
|
||||
"eos_token",
|
||||
"request_timeout",
|
||||
"complete_response",
|
||||
"self",
|
||||
"client",
|
||||
"rpm",
|
||||
"tpm",
|
||||
"input_cost_per_token",
|
||||
"output_cost_per_token",
|
||||
"hf_model_name",
|
||||
"model_info",
|
||||
"proxy_server_request",
|
||||
"preset_cache_key",
|
||||
]
|
||||
_current_cost = 0 # private variable, used if max budget is set
|
||||
error_logs: Dict = {}
|
||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||
client_session: Optional[httpx.Client] = None
|
||||
aclient_session: Optional[httpx.AsyncClient] = None
|
||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
suppress_debug_info = False
|
||||
dynamodb_table_name: Optional[str] = None
|
||||
|
@ -66,23 +140,35 @@ fallbacks: Optional[List] = None
|
|||
context_window_fallbacks: Optional[List] = None
|
||||
allowed_fails: int = 0
|
||||
####### SECRET MANAGERS #####################
|
||||
secret_manager_client: Optional[Any] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
secret_manager_client: Optional[
|
||||
Any
|
||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
#############################################
|
||||
|
||||
|
||||
def get_model_cost_map(url: str):
|
||||
try:
|
||||
with requests.get(url, timeout=5) as response: # set a 5 second timeout for the get request
|
||||
response.raise_for_status() # Raise an exception if the request is unsuccessful
|
||||
with requests.get(
|
||||
url, timeout=5
|
||||
) as response: # set a 5 second timeout for the get request
|
||||
response.raise_for_status() # Raise an exception if the request is unsuccessful
|
||||
content = response.json()
|
||||
return content
|
||||
except Exception as e:
|
||||
import importlib.resources
|
||||
import json
|
||||
with importlib.resources.open_text("litellm", "model_prices_and_context_window_backup.json") as f:
|
||||
|
||||
with importlib.resources.open_text(
|
||||
"litellm", "model_prices_and_context_window_backup.json"
|
||||
) as f:
|
||||
content = json.load(f)
|
||||
return content
|
||||
|
||||
|
||||
model_cost = get_model_cost_map(url=model_cost_map_url)
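Once loaded, model_cost is a plain dict keyed by model name, and it is what the provider loop below iterates over. An illustrative lookup (not part of the diff; the cost field names are assumptions about the map's schema):

# Sketch only: inspect one entry of the loaded cost map.
import litellm

entry = litellm.model_cost.get("gpt-3.5-turbo", {})
print(entry.get("litellm_provider"))  # provider tag used in the loop below
print(entry.get("input_cost_per_token"))  # assumed field name
print(entry.get("output_cost_per_token"))  # assumed field name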
|
||||
custom_prompt_dict:Dict[str, dict] = {}
|
||||
custom_prompt_dict: Dict[str, dict] = {}
|
||||
|
||||
|
||||
####### THREAD-SPECIFIC DATA ###################
|
||||
class MyLocal(threading.local):
|
||||
def __init__(self):
|
||||
|
@ -123,56 +209,51 @@ bedrock_models: List = []
|
|||
deepinfra_models: List = []
|
||||
perplexity_models: List = []
|
||||
for key, value in model_cost.items():
|
||||
if value.get('litellm_provider') == 'openai':
|
||||
if value.get("litellm_provider") == "openai":
|
||||
open_ai_chat_completion_models.append(key)
|
||||
elif value.get('litellm_provider') == 'text-completion-openai':
|
||||
elif value.get("litellm_provider") == "text-completion-openai":
|
||||
open_ai_text_completion_models.append(key)
|
||||
elif value.get('litellm_provider') == 'cohere':
|
||||
elif value.get("litellm_provider") == "cohere":
|
||||
cohere_models.append(key)
|
||||
elif value.get('litellm_provider') == 'anthropic':
|
||||
elif value.get("litellm_provider") == "anthropic":
|
||||
anthropic_models.append(key)
|
||||
elif value.get('litellm_provider') == 'openrouter':
|
||||
elif value.get("litellm_provider") == "openrouter":
|
||||
openrouter_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-text-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-text-models":
|
||||
vertex_text_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-code-text-models":
|
||||
vertex_code_text_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-language-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-language-models":
|
||||
vertex_language_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-vision-models":
|
||||
vertex_vision_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-chat-models":
|
||||
vertex_chat_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
|
||||
elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
|
||||
vertex_code_chat_models.append(key)
|
||||
elif value.get('litellm_provider') == 'ai21':
|
||||
elif value.get("litellm_provider") == "ai21":
|
||||
ai21_models.append(key)
|
||||
elif value.get('litellm_provider') == 'nlp_cloud':
|
||||
elif value.get("litellm_provider") == "nlp_cloud":
|
||||
nlp_cloud_models.append(key)
|
||||
elif value.get('litellm_provider') == 'aleph_alpha':
|
||||
elif value.get("litellm_provider") == "aleph_alpha":
|
||||
aleph_alpha_models.append(key)
|
||||
elif value.get('litellm_provider') == 'bedrock':
|
||||
elif value.get("litellm_provider") == "bedrock":
|
||||
bedrock_models.append(key)
|
||||
elif value.get('litellm_provider') == 'deepinfra':
|
||||
elif value.get("litellm_provider") == "deepinfra":
|
||||
deepinfra_models.append(key)
|
||||
elif value.get('litellm_provider') == 'perplexity':
|
||||
elif value.get("litellm_provider") == "perplexity":
|
||||
perplexity_models.append(key)
|
||||
|
||||
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
|
||||
openai_compatible_endpoints: List = [
|
||||
"api.perplexity.ai",
|
||||
"api.perplexity.ai",
|
||||
"api.endpoints.anyscale.com/v1",
|
||||
"api.deepinfra.com/v1/openai",
|
||||
"api.mistral.ai/v1"
|
||||
"api.mistral.ai/v1",
|
||||
]
|
||||
|
||||
# this is maintained for Exception Mapping
|
||||
openai_compatible_providers: List = [
|
||||
"anyscale",
|
||||
"mistral",
|
||||
"deepinfra",
|
||||
"perplexity"
|
||||
]
|
||||
openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
|
||||
|
||||
|
||||
# well supported replicate llms
|
||||
|
@ -209,23 +290,18 @@ huggingface_models: List = [
|
|||
together_ai_models: List = [
|
||||
# llama llms - chat
|
||||
"togethercomputer/llama-2-70b-chat",
|
||||
|
||||
# llama llms - language / instruct
|
||||
# llama llms - language / instruct
|
||||
"togethercomputer/llama-2-70b",
|
||||
"togethercomputer/LLaMA-2-7B-32K",
|
||||
"togethercomputer/Llama-2-7B-32K-Instruct",
|
||||
"togethercomputer/llama-2-7b",
|
||||
|
||||
# falcon llms
|
||||
"togethercomputer/falcon-40b-instruct",
|
||||
"togethercomputer/falcon-7b-instruct",
|
||||
|
||||
# alpaca
|
||||
"togethercomputer/alpaca-7b",
|
||||
|
||||
# chat llms
|
||||
"HuggingFaceH4/starchat-alpha",
|
||||
|
||||
# code llms
|
||||
"togethercomputer/CodeLlama-34b",
|
||||
"togethercomputer/CodeLlama-34b-Instruct",
|
||||
|
@ -234,29 +310,27 @@ together_ai_models: List = [
|
|||
"NumbersStation/nsql-llama-2-7B",
|
||||
"WizardLM/WizardCoder-15B-V1.0",
|
||||
"WizardLM/WizardCoder-Python-34B-V1.0",
|
||||
|
||||
# language llms
|
||||
"NousResearch/Nous-Hermes-Llama2-13b",
|
||||
"Austism/chronos-hermes-13b",
|
||||
"upstage/SOLAR-0-70b-16bit",
|
||||
"WizardLM/WizardLM-70B-V1.0",
|
||||
|
||||
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
|
||||
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
|
||||
|
||||
|
||||
baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML
|
||||
baseten_models: List = [
|
||||
"qvv0xeq",
|
||||
"q841o8w",
|
||||
"31dxrj3",
|
||||
] # FALCON 7B # WizardLM # Mosaic ML
|
||||
|
||||
petals_models = [
|
||||
"petals-team/StableBeluga2",
|
||||
]
|
||||
|
||||
ollama_models = [
|
||||
"llama2"
|
||||
]
|
||||
ollama_models = ["llama2"]
|
||||
|
||||
maritalk_models = [
|
||||
"maritalk"
|
||||
]
|
||||
maritalk_models = ["maritalk"]
|
||||
|
||||
model_list = (
|
||||
open_ai_chat_completion_models
|
||||
|
@ -308,7 +382,7 @@ provider_list: List = [
|
|||
"anyscale",
|
||||
"mistral",
|
||||
"maritalk",
|
||||
"custom", # custom apis
|
||||
"custom", # custom apis
|
||||
]
|
||||
|
||||
models_by_provider: dict = {
|
||||
|
@ -327,28 +401,28 @@ models_by_provider: dict = {
|
|||
"ollama": ollama_models,
|
||||
"deepinfra": deepinfra_models,
|
||||
"perplexity": perplexity_models,
|
||||
"maritalk": maritalk_models
|
||||
"maritalk": maritalk_models,
|
||||
}
|
||||
|
||||
# mapping for those models which have larger equivalents
|
||||
# mapping for those models which have larger equivalents
|
||||
longer_context_model_fallback_dict: dict = {
|
||||
# openai chat completion models
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
|
||||
"gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4": "gpt-4-32k",
|
||||
"gpt-4-0314": "gpt-4-32k-0314",
|
||||
"gpt-4-0613": "gpt-4-32k-0613",
|
||||
# anthropic
|
||||
"claude-instant-1": "claude-2",
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
|
||||
"gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4": "gpt-4-32k",
|
||||
"gpt-4-0314": "gpt-4-32k-0314",
|
||||
"gpt-4-0613": "gpt-4-32k-0613",
|
||||
# anthropic
|
||||
"claude-instant-1": "claude-2",
|
||||
"claude-instant-1.2": "claude-2",
|
||||
# vertexai
|
||||
"chat-bison": "chat-bison-32k",
|
||||
"chat-bison@001": "chat-bison-32k",
|
||||
"codechat-bison": "codechat-bison-32k",
|
||||
"codechat-bison": "codechat-bison-32k",
|
||||
"codechat-bison@001": "codechat-bison-32k",
|
||||
# openrouter
|
||||
"openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
|
||||
# openrouter
|
||||
"openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
|
||||
"openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
|
||||
}
|
||||
|
||||
|
@ -357,20 +431,23 @@ open_ai_embedding_models: List = ["text-embedding-ada-002"]
|
|||
cohere_embedding_models: List = [
|
||||
"embed-english-v3.0",
|
||||
"embed-english-light-v3.0",
|
||||
"embed-multilingual-v3.0",
|
||||
"embed-english-v2.0",
|
||||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
"embed-multilingual-v3.0",
|
||||
"embed-english-v2.0",
|
||||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
]
|
||||
bedrock_embedding_models: List = [
|
||||
"amazon.titan-embed-text-v1",
|
||||
"cohere.embed-english-v3",
|
||||
"cohere.embed-multilingual-v3",
|
||||
]
|
||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
|
||||
|
||||
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||
all_embedding_models = (
|
||||
open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||
)
|
||||
|
||||
####### IMAGE GENERATION MODELS ###################
|
||||
openai_image_generation_models = [
|
||||
"dall-e-2",
|
||||
"dall-e-3"
|
||||
]
|
||||
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
||||
|
||||
|
||||
from .timeout import timeout
|
||||
|
@ -394,11 +471,11 @@ from .utils import (
|
|||
get_llm_provider,
|
||||
completion_with_config,
|
||||
register_model,
|
||||
encode,
|
||||
decode,
|
||||
encode,
|
||||
decode,
|
||||
_calculate_retry_after,
|
||||
_should_retry,
|
||||
get_secret
|
||||
get_secret,
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
|
@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig
|
|||
from .llms.sagemaker import SagemakerConfig
|
||||
from .llms.ollama import OllamaConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
|
||||
from .llms.bedrock import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
AmazonAnthropicConfig,
|
||||
AmazonCohereConfig,
|
||||
AmazonLlamaConfig,
|
||||
)
|
||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
|
||||
from .llms.azure import AzureOpenAIConfig
|
||||
from .main import * # type: ignore
|
||||
|
@ -429,13 +512,13 @@ from .exceptions import (
|
|||
ServiceUnavailableError,
|
||||
OpenAIError,
|
||||
ContextWindowExceededError,
|
||||
BudgetExceededError,
|
||||
BudgetExceededError,
|
||||
APIError,
|
||||
Timeout,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
from .budget_manager import BudgetManager
|
||||
from .proxy.proxy_cli import run_server
|
||||
from .router import Router
|
||||
from .router import Router
|
||||
|
|
|
@@ -1,8 +1,9 @@
set_verbose = False


def print_verbose(print_statement):
try:
if set_verbose:
print(print_statement) # noqa
print(print_statement) # noqa
except:
pass
pass
|
|
@ -13,6 +13,7 @@ import inspect
|
|||
import redis, litellm
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def _get_redis_kwargs():
|
||||
arg_spec = inspect.getfullargspec(redis.Redis)
|
||||
|
||||
|
@ -23,32 +24,26 @@ def _get_redis_kwargs():
|
|||
"retry",
|
||||
}
|
||||
|
||||
|
||||
include_args = [
|
||||
"url"
|
||||
]
|
||||
include_args = ["url"]
|
||||
|
||||
available_args = [
|
||||
x for x in arg_spec.args if x not in exclude_args
|
||||
] + include_args
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_env_kwarg_mapping():
|
||||
PREFIX = "REDIS_"
|
||||
|
||||
return {
|
||||
f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()
|
||||
}
|
||||
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
|
||||
|
||||
|
||||
def _redis_kwargs_from_environment():
|
||||
mapping = _get_redis_env_kwarg_mapping()
|
||||
|
||||
return_dict = {}
|
||||
return_dict = {}
|
||||
for k, v in mapping.items():
|
||||
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
|
||||
if value is not None:
|
||||
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
|
||||
if value is not None:
|
||||
return_dict[v] = value
|
||||
return return_dict
|
||||
|
||||
|
@ -56,21 +51,26 @@ def _redis_kwargs_from_environment():
|
|||
def get_redis_url_from_environment():
|
||||
if "REDIS_URL" in os.environ:
|
||||
return os.environ["REDIS_URL"]
|
||||
|
||||
|
||||
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
|
||||
raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")
|
||||
raise ValueError(
|
||||
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
|
||||
)
|
||||
|
||||
if "REDIS_PASSWORD" in os.environ:
|
||||
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
|
||||
else:
|
||||
redis_password = ""
|
||||
|
||||
return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||
return (
|
||||
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||
)
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
### check if "os.environ/<key-name>" passed in
|
||||
for k, v in env_overrides.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
for k, v in env_overrides.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
v = v.replace("os.environ/", "")
|
||||
value = litellm.get_secret(v)
|
||||
env_overrides[k] = value
|
||||
|
@ -80,14 +80,14 @@ def get_redis_client(**env_overrides):
|
|||
**env_overrides,
|
||||
}
|
||||
|
||||
if "url" in redis_kwargs and redis_kwargs['url'] is not None:
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop("host", None)
|
||||
redis_kwargs.pop("port", None)
|
||||
redis_kwargs.pop("db", None)
|
||||
redis_kwargs.pop("password", None)
|
||||
|
||||
|
||||
return redis.Redis.from_url(**redis_kwargs)
|
||||
elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
|
||||
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis.Redis(**redis_kwargs)
|
||||
return redis.Redis(**redis_kwargs)
|
||||
|
|
|
@ -1,119 +1,166 @@
|
|||
import os, json, time
|
||||
import litellm
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse
|
||||
import requests, threading
|
||||
from typing import Optional, Union, Literal
|
||||
|
||||
|
||||
class BudgetManager:
|
||||
def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
project_name: str,
|
||||
client_type: str = "local",
|
||||
api_base: Optional[str] = None,
|
||||
):
|
||||
self.client_type = client_type
|
||||
self.project_name = project_name
|
||||
self.api_base = api_base or "https://api.litellm.ai"
|
||||
## load the data or init the initial dictionaries
|
||||
self.load_data()
|
||||
|
||||
self.load_data()
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
import logging
|
||||
|
||||
logging.info(print_statement)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def load_data(self):
|
||||
if self.client_type == "local":
|
||||
# Check if user dict file exists
|
||||
if os.path.isfile("user_cost.json"):
|
||||
# Load the user dict
|
||||
with open("user_cost.json", 'r') as json_file:
|
||||
with open("user_cost.json", "r") as json_file:
|
||||
self.user_dict = json.load(json_file)
|
||||
else:
|
||||
self.print_verbose("User Dictionary not found!")
|
||||
self.user_dict = {}
|
||||
self.user_dict = {}
|
||||
self.print_verbose(f"user dict from local: {self.user_dict}")
|
||||
elif self.client_type == "hosted":
|
||||
# Load the user_dict from hosted db
|
||||
url = self.api_base + "/get_budget"
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
data = {
|
||||
'project_name' : self.project_name
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"project_name": self.project_name}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response = response.json()
|
||||
if response["status"] == "error":
|
||||
self.user_dict = {} # assume this means the user dict hasn't been stored yet
|
||||
self.user_dict = (
|
||||
{}
|
||||
) # assume this means the user dict hasn't been stored yet
|
||||
else:
|
||||
self.user_dict = response["data"]
|
||||
|
||||
def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float = time.time()):
|
||||
def create_budget(
|
||||
self,
|
||||
total_budget: float,
|
||||
user: str,
|
||||
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
|
||||
created_at: float = time.time(),
|
||||
):
|
||||
self.user_dict[user] = {"total_budget": total_budget}
|
||||
if duration is None:
|
||||
return self.user_dict[user]
|
||||
|
||||
if duration == 'daily':
|
||||
|
||||
if duration == "daily":
|
||||
duration_in_days = 1
|
||||
elif duration == 'weekly':
|
||||
elif duration == "weekly":
|
||||
duration_in_days = 7
|
||||
elif duration == 'monthly':
|
||||
elif duration == "monthly":
|
||||
duration_in_days = 28
|
||||
elif duration == 'yearly':
|
||||
elif duration == "yearly":
|
||||
duration_in_days = 365
|
||||
else:
|
||||
raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""")
|
||||
self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at}
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
raise ValueError(
|
||||
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
||||
)
|
||||
self.user_dict[user] = {
|
||||
"total_budget": total_budget,
|
||||
"duration": duration_in_days,
|
||||
"created_at": created_at,
|
||||
"last_updated_at": created_at,
|
||||
}
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
return self.user_dict[user]
|
||||
|
||||
|
||||
    def projected_cost(self, model: str, messages: list, user: str):
        text = "".join(message["content"] for message in messages)
        prompt_tokens = litellm.token_counter(model=model, text=text)
        prompt_cost, _ = litellm.cost_per_token(
            model=model, prompt_tokens=prompt_tokens, completion_tokens=0
        )
        current_cost = self.user_dict[user].get("current_cost", 0)
        projected_cost = prompt_cost + current_cost
        return projected_cost

    def get_total_budget(self, user: str):
        return self.user_dict[user]["total_budget"]

    def update_cost(
        self,
        user: str,
        completion_obj: Optional[ModelResponse] = None,
        model: Optional[str] = None,
        input_text: Optional[str] = None,
        output_text: Optional[str] = None,
    ):
        if model and input_text and output_text:
            prompt_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": input_text}]
            )
            completion_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": output_text}]
            )
            (
                prompt_tokens_cost_usd_dollar,
                completion_tokens_cost_usd_dollar,
            ) = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
        elif completion_obj:
            cost = litellm.completion_cost(completion_response=completion_obj)
            model = completion_obj[
                "model"
            ]  # if this throws an error try, model = completion_obj['model']
        else:
            raise ValueError(
                "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
            )

        self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
            "current_cost", 0
        )
        if "model_cost" in self.user_dict[user]:
            self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
                "model_cost"
            ].get(model, 0)
        else:
            self.user_dict[user]["model_cost"] = {model: cost}

        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return {"user": self.user_dict[user]}

    def get_current_cost(self, user):
        return self.user_dict[user].get("current_cost", 0)

    def get_model_cost(self, user):
        return self.user_dict[user].get("model_cost", 0)

    def is_valid_user(self, user: str) -> bool:
        return user in self.user_dict

    def get_users(self):
        return list(self.user_dict.keys())

    def reset_cost(self, user):
        self.user_dict[user]["current_cost"] = 0
        self.user_dict[user]["model_cost"] = {}
        return {"user": self.user_dict[user]}

    def reset_on_duration(self, user: str):
        # Get current and creation time
        last_updated_at = self.user_dict[user]["last_updated_at"]
@ -121,38 +168,39 @@ class BudgetManager:
        # Convert duration from days to seconds
        duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60

        # Check if duration has elapsed
        if current_time - last_updated_at >= duration_in_seconds:
            # Reset cost if duration has elapsed and update the creation time
            self.reset_cost(user)
            self.user_dict[user]["last_updated_at"] = current_time
            self._save_data_thread()  # Save the data

    def update_budget_all_users(self):
        for user in self.get_users():
            if "duration" in self.user_dict[user]:
                self.reset_on_duration(user)

    def _save_data_thread(self):
        thread = threading.Thread(
            target=self.save_data
        )  # [Non-Blocking]: saves data without blocking execution
        thread.start()

    def save_data(self):
        if self.client_type == "local":
            import json

            # save the user dict
            with open("user_cost.json", "w") as json_file:
                json.dump(
                    self.user_dict, json_file, indent=4
                )  # Indent for pretty formatting
            return {"status": "success"}
        elif self.client_type == "hosted":
            url = self.api_base + "/set_budget"
            headers = {"Content-Type": "application/json"}
            data = {"project_name": self.project_name, "user_dict": self.user_dict}
            response = requests.post(url, headers=headers, json=data)
            response = response.json()
            return response
|
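For orientation, here is a minimal usage sketch of the BudgetManager shown in this diff; it is illustrative only, and the project name and user id ("demo_project", "user_123") are made-up placeholders, not values from the commit.

# Sketch: track spend per user with the BudgetManager class above (placeholder names).
import litellm
from litellm import BudgetManager, completion

budget_manager = BudgetManager(project_name="demo_project")  # placeholder project name
user = "user_123"  # placeholder user id

budget_manager.create_budget(total_budget=10.0, user=user, duration="daily")
if budget_manager.get_current_cost(user=user) <= budget_manager.get_total_budget(user):
    response = completion(
        model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}]
    )
    budget_manager.update_cost(completion_obj=response, user=user)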
@ -12,13 +12,15 @@ import time, logging
import json, traceback, ast
from typing import Optional, Literal, List


def print_verbose(print_statement):
    try:
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except:
        pass


class BaseCache:
    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError
@ -45,13 +47,13 @@ class InMemoryCache(BaseCache):
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except:
                cached_response = original_cached_response
            return cached_response
        return None

    def flush_cache(self):
        self.cache_dict.clear()
        self.ttl_dict.clear()
@ -60,17 +62,18 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache):
    def __init__(self, host=None, port=None, password=None, **kwargs):
        import redis

        # if users don't provide one, use the default litellm cache
        from ._redis import get_redis_client

        redis_kwargs = {}
        if host is not None:
            redis_kwargs["host"] = host
        if port is not None:
            redis_kwargs["port"] = port
        if password is not None:
            redis_kwargs["password"] = password

        redis_kwargs.update(kwargs)

        self.redis_client = get_redis_client(**redis_kwargs)
@ -88,13 +91,19 @@ class RedisCache(BaseCache):
        try:
            print_verbose(f"Get Redis Cache: key: {key}")
            cached_response = self.redis_client.get(key)
            print_verbose(
                f"Got Redis Cache: key: {key}, cached_response {cached_response}"
            )
            if cached_response != None:
                # cached_response is in `b{}` (bytes) form, convert it to ModelResponse
                cached_response = cached_response.decode(
                    "utf-8"
                )  # Convert bytes to string
                try:
                    cached_response = json.loads(
                        cached_response
                    )  # Convert string to dictionary
                except:
                    cached_response = ast.literal_eval(cached_response)
                return cached_response
        except Exception as e:
@ -105,34 +114,40 @@ class RedisCache(BaseCache):
    def flush_cache(self):
        self.redis_client.flushall()


class DualCache(BaseCache):
    """
    This updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
    ) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache

    def set_cache(self, key, value, **kwargs):
        # Update both Redis and in-memory cache
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def get_cache(self, key, **kwargs):
        # Try to fetch from in-memory cache first
        try:
            print_verbose(f"get cache: cache key: {key}")
            result = None
            if self.in_memory_cache is not None:
@ -141,7 +156,7 @@ class DualCache(BaseCache):
                if in_memory_result is not None:
                    result = in_memory_result

            if self.redis_cache is not None:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(key, **kwargs)
@ -153,25 +168,28 @@ class DualCache(BaseCache):
            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception as e:
            traceback.print_exc()

    def flush_cache(self):
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[Literal["local", "redis"]] = "local",
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        supported_call_types: Optional[
            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
        ] = ["completion", "acompletion", "embedding", "aembedding"],
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.
@ -200,7 +218,7 @@ class Cache:
            litellm.success_callback.append("cache")
        if "cache" not in litellm._async_success_callback:
            litellm._async_success_callback.append("cache")
        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]

    def get_cache_key(self, *args, **kwargs):
        """
@ -215,18 +233,37 @@ class Cache:
        """
        cache_key = ""
        print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")

        # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
        if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
            print_verbose(f"\nReturning preset cache key: {cache_key}")
            return kwargs.get("litellm_params", {}).get("preset_cache_key", None)

        # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
        completion_kwargs = [
            "model",
            "messages",
            "temperature",
            "top_p",
            "n",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
        ]
        embedding_only_kwargs = [
            "input",
            "encoding_format",
        ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs

        # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
        combined_kwargs = completion_kwargs + embedding_only_kwargs
        for param in combined_kwargs:
            # ignore litellm params here
            if param in kwargs:
@ -241,8 +278,8 @@ class Cache:
                        model_group = metadata.get("model_group", None)
                        caching_groups = metadata.get("caching_groups", None)
                        if caching_groups:
                            for group in caching_groups:
                                if model_group in group:
                                    caching_group = group
                                    break
                if litellm_params is not None:
@ -251,23 +288,34 @@ class Cache:
                            model_group = metadata.get("model_group", None)
                            caching_groups = metadata.get("caching_groups", None)
                            if caching_groups:
                                for group in caching_groups:
                                    if model_group in group:
                                        caching_group = group
                                        break
                    param_value = (
                        caching_group or model_group or kwargs[param]
                    )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
                else:
                    if kwargs[param] is None:
                        continue  # ignore None params
                    param_value = kwargs[param]
                cache_key += f"{str(param)}: {str(param_value)}"
        print_verbose(f"\nCreated cache key: {cache_key}")
        return cache_key

    def generate_streaming_content(self, content):
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(0.02)

    def get_cache(self, *args, **kwargs):
@ -319,4 +367,4 @@ class Cache:
            pass

    async def _async_add_cache(self, result, *args, **kwargs):
        self.add_cache(result, *args, **kwargs)
|
|
|
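A minimal sketch of how the Cache class above is typically wired in, assuming the default in-process "local" cache; the prompt text is a placeholder.

# Sketch: enable response caching with the Cache class above (local, in-memory).
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(type="local")

messages = [{"role": "user", "content": "Tell me a joke."}]  # placeholder prompt
first = completion(model="gpt-3.5-turbo", messages=messages)
second = completion(model="gpt-3.5-turbo", messages=messages)  # expected cache hit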
@ -1,2 +1,2 @@
|
|||
# from .main import *
|
||||
# from .server_utils import *
|
||||
# from .server_utils import *
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
# llm_model_list: Optional[list] = None
|
||||
# server_settings: Optional[dict] = None
|
||||
|
||||
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment
|
||||
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment
|
||||
|
||||
# if "CONFIG_FILE_PATH" in os.environ:
|
||||
# llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
|
||||
|
@ -44,7 +44,7 @@
|
|||
# @router.get("/models") # if project requires model list
|
||||
# def model_list():
|
||||
# all_models = litellm.utils.get_valid_models()
|
||||
# if llm_model_list:
|
||||
# if llm_model_list:
|
||||
# all_models += llm_model_list
|
||||
# return dict(
|
||||
# data=[
|
||||
|
@ -79,8 +79,8 @@
|
|||
# @router.post("/v1/embeddings")
|
||||
# @router.post("/embeddings")
|
||||
# async def embedding(request: Request):
|
||||
# try:
|
||||
# data = await request.json()
|
||||
# try:
|
||||
# data = await request.json()
|
||||
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
|
||||
# if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header
|
||||
# api_key = request.headers.get("authorization")
|
||||
|
@ -106,13 +106,13 @@
|
|||
# data = await request.json()
|
||||
# server_model = server_settings.get("completion_model", None) if server_settings else None
|
||||
# data["model"] = server_model or model or data["model"]
|
||||
# ## CHECK KEYS ##
|
||||
# ## CHECK KEYS ##
|
||||
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
|
||||
# # env_validation = litellm.validate_environment(model=data["model"])
|
||||
# # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
|
||||
# # if "authorization" in request.headers:
|
||||
# # api_key = request.headers.get("authorization")
|
||||
# # elif "api-key" in request.headers:
|
||||
# # elif "api-key" in request.headers:
|
||||
# # api_key = request.headers.get("api-key")
|
||||
# # print(f"api_key in headers: {api_key}")
|
||||
# # if " " in api_key:
|
||||
|
@ -122,11 +122,11 @@
|
|||
# # api_key = api_key
|
||||
# # data["api_key"] = api_key
|
||||
# # print(f"api_key in data: {api_key}")
|
||||
# ## CHECK CONFIG ##
|
||||
# ## CHECK CONFIG ##
|
||||
# if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
|
||||
# for m in llm_model_list:
|
||||
# if data["model"] == m["model_name"]:
|
||||
# for key, value in m["litellm_params"].items():
|
||||
# for m in llm_model_list:
|
||||
# if data["model"] == m["model_name"]:
|
||||
# for key, value in m["litellm_params"].items():
|
||||
# data[key] = value
|
||||
# break
|
||||
# response = litellm.completion(
|
||||
|
@ -145,21 +145,21 @@
|
|||
# @router.post("/router/completions")
|
||||
# async def router_completion(request: Request):
|
||||
# global llm_router
|
||||
# try:
|
||||
# try:
|
||||
# data = await request.json()
|
||||
# if "model_list" in data:
|
||||
# if "model_list" in data:
|
||||
# llm_router = litellm.Router(model_list=data.pop("model_list"))
|
||||
# if llm_router is None:
|
||||
# if llm_router is None:
|
||||
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
|
||||
|
||||
|
||||
# # openai.ChatCompletion.create replacement
|
||||
# response = await llm_router.acompletion(model="gpt-3.5-turbo",
|
||||
# response = await llm_router.acompletion(model="gpt-3.5-turbo",
|
||||
# messages=[{"role": "user", "content": "Hey, how's it going?"}])
|
||||
|
||||
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||
# return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||
# return response
|
||||
# except Exception as e:
|
||||
# except Exception as e:
|
||||
# error_traceback = traceback.format_exc()
|
||||
# error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
# return {"error": error_msg}
|
||||
|
@ -167,11 +167,11 @@
|
|||
# @router.post("/router/embedding")
|
||||
# async def router_embedding(request: Request):
|
||||
# global llm_router
|
||||
# try:
|
||||
# try:
|
||||
# data = await request.json()
|
||||
# if "model_list" in data:
|
||||
# if "model_list" in data:
|
||||
# llm_router = litellm.Router(model_list=data.pop("model_list"))
|
||||
# if llm_router is None:
|
||||
# if llm_router is None:
|
||||
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
|
||||
|
||||
# response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore
|
||||
|
@ -180,7 +180,7 @@
|
|||
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||
# return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||
# return response
|
||||
# except Exception as e:
|
||||
# except Exception as e:
|
||||
# error_traceback = traceback.format_exc()
|
||||
# error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
# return {"error": error_msg}
|
||||
|
@ -190,4 +190,4 @@
|
|||
# return "LiteLLM: RUNNING"
|
||||
|
||||
|
||||
# app.include_router(router)
|
||||
# app.include_router(router)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# import dotenv
|
||||
# dotenv.load_dotenv() # load env variables
|
||||
|
||||
# def print_verbose(print_statement):
|
||||
# def print_verbose(print_statement):
|
||||
# pass
|
||||
|
||||
# def get_package_version(package_name):
|
||||
|
@ -27,32 +27,31 @@
|
|||
|
||||
# def set_callbacks():
|
||||
# ## LOGGING
|
||||
# if len(os.getenv("SET_VERBOSE", "")) > 0:
|
||||
# if os.getenv("SET_VERBOSE") == "True":
|
||||
# if len(os.getenv("SET_VERBOSE", "")) > 0:
|
||||
# if os.getenv("SET_VERBOSE") == "True":
|
||||
# litellm.set_verbose = True
|
||||
# print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m")
|
||||
# else:
|
||||
# else:
|
||||
# litellm.set_verbose = False
|
||||
|
||||
# ### LANGFUSE
|
||||
# if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
|
||||
# litellm.success_callback = ["langfuse"]
|
||||
# litellm.success_callback = ["langfuse"]
|
||||
# print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m")
|
||||
|
||||
# ## CACHING
|
||||
|
||||
# ## CACHING
|
||||
# ### REDIS
|
||||
# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
|
||||
# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
|
||||
# # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
|
||||
# # from litellm.caching import Cache
|
||||
# # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
||||
# # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
|
||||
|
||||
|
||||
|
||||
# def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
|
||||
# config = {}
|
||||
# server_settings = {}
|
||||
# try:
|
||||
# server_settings = {}
|
||||
# try:
|
||||
# if os.path.exists(config_file_path): # type: ignore
|
||||
# with open(config_file_path, 'r') as file: # type: ignore
|
||||
# config = yaml.safe_load(file)
|
||||
|
@ -63,24 +62,24 @@
|
|||
|
||||
# ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
|
||||
# server_settings = config.get("server_settings", None)
|
||||
# if server_settings:
|
||||
# if server_settings:
|
||||
# server_settings = server_settings
|
||||
|
||||
# ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
# litellm_settings = config.get('litellm_settings', None)
|
||||
# if litellm_settings:
|
||||
# for key, value in litellm_settings.items():
|
||||
# if litellm_settings:
|
||||
# for key, value in litellm_settings.items():
|
||||
# setattr(litellm, key, value)
|
||||
|
||||
# ## MODEL LIST
|
||||
# model_list = config.get('model_list', None)
|
||||
# if model_list:
|
||||
# if model_list:
|
||||
# router = litellm.Router(model_list=model_list)
|
||||
|
||||
|
||||
# ## ENVIRONMENT VARIABLES
|
||||
# environment_variables = config.get('environment_variables', None)
|
||||
# if environment_variables:
|
||||
# for key, value in environment_variables.items():
|
||||
# if environment_variables:
|
||||
# for key, value in environment_variables.items():
|
||||
# os.environ[key] = value
|
||||
|
||||
# return router, model_list, server_settings
|
||||
|
|
|
@ -16,11 +16,11 @@ from openai import (
|
|||
RateLimitError,
|
||||
APIStatusError,
|
||||
OpenAIError,
|
||||
APIError,
|
||||
APITimeoutError,
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APITimeoutError,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
import httpx
|
||||
|
||||
|
@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore
|
|||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# raise when invalid models passed, example gpt-8
|
||||
class NotFoundError(NotFoundError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
|
@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore
|
|||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
|
@ -58,23 +55,21 @@ class BadRequestError(BadRequestError): # type: ignore
|
|||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||
|
||||
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
self.status_code = 422
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class Timeout(APITimeoutError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 408
|
||||
|
@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore
|
|||
request=request
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class RateLimitError(RateLimitError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
self.status_code = 429
|
||||
|
@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore
|
|||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
|
||||
class ContextWindowExceededError(BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
|
@ -106,12 +101,13 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
|
|||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=response
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=response,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class ServiceUnavailableError(APIStatusError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
self.status_code = 503
|
||||
|
@ -119,50 +115,42 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
|
|||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
||||
class APIError(APIError): # type: ignore
|
||||
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request):
|
||||
self.status_code = status_code
|
||||
class APIError(APIError): # type: ignore
|
||||
def __init__(
|
||||
self, status_code, message, llm_provider, model, request: httpx.Request
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
super().__init__(
|
||||
self.message,
|
||||
request=request, # type: ignore
|
||||
body=None
|
||||
)
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIConnectionError(APIConnectionError): # type: ignore
|
||||
class APIConnectionError(APIConnectionError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, request: httpx.Request):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.status_code = 500
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
request=request
|
||||
)
|
||||
super().__init__(message=self.message, request=request)
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIResponseValidationError(APIResponseValidationError): # type: ignore
|
||||
class APIResponseValidationError(APIResponseValidationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
super().__init__(
|
||||
response=response,
|
||||
body=None,
|
||||
message=message
|
||||
)
|
||||
super().__init__(response=response, body=None, message=message)
|
||||
|
||||
|
||||
class OpenAIError(OpenAIError): # type: ignore
|
||||
def __init__(self, original_exception):
|
||||
|
@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore
|
|||
)
|
||||
self.llm_provider = "openai"
|
||||
|
||||
|
||||
class BudgetExceededError(Exception):
|
||||
def __init__(self, current_cost, max_budget):
|
||||
self.current_cost = current_cost
|
||||
|
@ -183,7 +172,8 @@ class BudgetExceededError(Exception):
|
|||
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
|
||||
super().__init__(message)
|
||||
|
||||
## DEPRECATED ##
|
||||
|
||||
## DEPRECATED ##
|
||||
class InvalidRequestError(BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 400
|
||||
|
|
|
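Because these classes subclass the corresponding openai-python exceptions, callers can catch them directly; a hedged sketch (model and prompt are placeholders):

# Sketch: handling the mapped exception types defined above.
import litellm
from litellm import completion

try:
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
except litellm.exceptions.RateLimitError as e:
    print(f"Rate limited by {e.llm_provider}: {e.message}")
except litellm.exceptions.APIConnectionError as e:
    print(f"Connection problem (status {e.status_code}): {e.message}")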
@ -5,32 +5,33 @@ import requests
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
from typing import Literal
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
|
@ -43,81 +44,87 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
|
|||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
#### CALL HOOKS - proxy only ####
|
||||
"""
|
||||
Control / modify the incoming and outgoing data before calling the model
|
||||
"""
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal["completion", "embeddings"],
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
pass
|
||||
|
||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||
|
||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
try:
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["log_event_type"] = "pre_api_call"
|
||||
callback_func(
|
||||
kwargs,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - model call details: {kwargs}"
|
||||
)
|
||||
except:
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
try:
|
||||
async def async_log_input_event(
|
||||
self, model, messages, kwargs, print_verbose, callback_func
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["log_event_type"] = "pre_api_call"
|
||||
await callback_func(
|
||||
kwargs,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - model call details: {kwargs}"
|
||||
)
|
||||
except:
|
||||
print_verbose(f"Custom Logger - model call details: {kwargs}")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
|
||||
def log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
kwargs["log_event_type"] = "post_api_call"
|
||||
callback_func(
|
||||
kwargs, # kwargs to func
|
||||
kwargs, # kwargs to func
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - final response object: {response_obj}"
|
||||
)
|
||||
print_verbose(f"Custom Logger - final response object: {response_obj}")
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
|
||||
|
||||
async def async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
kwargs["log_event_type"] = "post_api_call"
|
||||
await callback_func(
|
||||
kwargs, # kwargs to func
|
||||
kwargs, # kwargs to func
|
||||
response_obj,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
print_verbose(
|
||||
f"Custom Logger - final response object: {response_obj}"
|
||||
)
|
||||
print_verbose(f"Custom Logger - final response object: {response_obj}")
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
|
||||
pass
|
||||
pass
|
||||
|
|
|
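A minimal sketch of subclassing the CustomLogger above and registering it; MyLogger and its print statements are hypothetical, not part of this commit.

# Sketch: a hypothetical CustomLogger subclass registered via litellm.callbacks.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyLogger(CustomLogger):  # hypothetical subclass name
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"success: model={kwargs.get('model')}")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"failure: model={kwargs.get('model')}")


litellm.callbacks = [MyLogger()]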
@ -10,19 +10,28 @@ import datetime, subprocess, sys
|
|||
import litellm, uuid
|
||||
from litellm._logging import print_verbose
|
||||
|
||||
|
||||
class DyanmoDBLogger:
|
||||
# Class variables or attributes
|
||||
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
import boto3
|
||||
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
|
||||
|
||||
self.dynamodb = boto3.resource(
|
||||
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
|
||||
)
|
||||
if litellm.dynamodb_table_name is None:
|
||||
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
|
||||
raise ValueError(
|
||||
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
|
||||
)
|
||||
self.table_name = litellm.dynamodb_table_name
|
||||
|
||||
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
async def _async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose
|
||||
):
|
||||
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
try:
|
||||
print_verbose(
|
||||
|
@ -32,7 +41,9 @@ class DyanmoDBLogger:
|
|||
# construct payload to send to DynamoDB
|
||||
# follows the same params as langfuse.py
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
messages = kwargs.get("messages")
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
call_type = kwargs.get("call_type", "litellm.completion")
|
||||
|
@ -51,7 +62,7 @@ class DyanmoDBLogger:
|
|||
"messages": messages,
|
||||
"response": response_obj,
|
||||
"usage": usage,
|
||||
"metadata": metadata
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
# Ensure everything in the payload is converted to str
|
||||
|
@ -62,9 +73,8 @@ class DyanmoDBLogger:
|
|||
# non blocking if it can't cast to a str
|
||||
pass
|
||||
|
||||
|
||||
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
|
||||
|
||||
|
||||
# put data in DynamoDB
|
||||
table = self.dynamodb.Table(self.table_name)
|
||||
# Assuming log_data is a dictionary with log information
|
||||
|
@ -79,4 +89,4 @@ class DyanmoDBLogger:
|
|||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -64,7 +64,9 @@ class LangFuseLogger:
|
|||
# end of processing langfuse ########################
|
||||
input = prompt
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}")
|
||||
print(
|
||||
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
|
||||
)
|
||||
self._log_langfuse_v2(
|
||||
user_id,
|
||||
metadata,
|
||||
|
@ -171,7 +173,6 @@ class LangFuseLogger:
|
|||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
trace.generation(
|
||||
name=metadata.get("generation_name", "litellm-completion"),
|
||||
startTime=start_time,
|
||||
|
|
|
@ -8,25 +8,27 @@ from datetime import datetime
|
|||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
class LangsmithLogger:
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
|
||||
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
# Method definition
|
||||
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
|
||||
metadata = {}
|
||||
if "litellm_params" in kwargs:
|
||||
metadata = kwargs["litellm_params"].get("metadata", {})
|
||||
# set project name and run_name for langsmith logging
|
||||
# set project name and run_name for langsmith logging
|
||||
# users can pass project_name and run name to litellm.completion()
|
||||
# Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"})
|
||||
# if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
|
||||
project_name = metadata.get("project_name", "litellm-completion")
|
||||
run_name = metadata.get("run_name", "LLMRun")
|
||||
print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}")
|
||||
print_verbose(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
)
|
||||
try:
|
||||
print_verbose(
|
||||
f"Langsmith Logging - Enters logging function for model {kwargs}"
|
||||
|
@ -34,6 +36,7 @@ class LangsmithLogger:
|
|||
import requests
|
||||
import datetime
|
||||
from datetime import timezone
|
||||
|
||||
try:
|
||||
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
|
||||
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
|
||||
|
@ -45,7 +48,7 @@ class LangsmithLogger:
|
|||
new_kwargs = {}
|
||||
for key in kwargs:
|
||||
value = kwargs[key]
|
||||
if key == "start_time" or key =="end_time":
|
||||
if key == "start_time" or key == "end_time":
|
||||
pass
|
||||
elif type(value) != dict:
|
||||
new_kwargs[key] = value
|
||||
|
@ -54,18 +57,14 @@ class LangsmithLogger:
|
|||
"https://api.smith.langchain.com/runs",
|
||||
json={
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": {
|
||||
**new_kwargs
|
||||
},
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
"inputs": {**new_kwargs},
|
||||
"outputs": response_obj.json(),
|
||||
"session_name": project_name,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
headers={
|
||||
"x-api-key": self.langsmith_api_key
|
||||
}
|
||||
headers={"x-api-key": self.langsmith_api_key},
|
||||
)
|
||||
print_verbose(
|
||||
f"Langsmith Layer Logging - final response object: {response_obj}"
|
||||
|
|
|
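For context, the LangsmithLogger above is normally enabled through litellm's success callbacks; a hedged sketch, where the project and run names are the placeholder examples mentioned in the comments of this diff:

# Sketch: enable Langsmith logging via litellm's success callbacks.
import os
import litellm
from litellm import completion

os.environ["LANGSMITH_API_KEY"] = ""  # set your Langsmith key
litellm.success_callback = ["langsmith"]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"},
)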
@ -1,6 +1,7 @@
|
|||
import requests, traceback, json, os
|
||||
import types
|
||||
|
||||
|
||||
class LiteDebugger:
|
||||
user_email = None
|
||||
dashboard_url = None
|
||||
|
@ -12,9 +13,15 @@ class LiteDebugger:
|
|||
|
||||
def validate_environment(self, email):
|
||||
try:
|
||||
self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL"))
|
||||
if self.user_email == None: # if users are trying to use_client=True but token not set
|
||||
raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token")
|
||||
self.user_email = (
|
||||
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
|
||||
)
|
||||
if (
|
||||
self.user_email == None
|
||||
): # if users are trying to use_client=True but token not set
|
||||
raise ValueError(
|
||||
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
|
||||
)
|
||||
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
|
||||
try:
|
||||
print(
|
||||
|
@ -42,7 +49,9 @@ class LiteDebugger:
|
|||
litellm_params,
|
||||
optional_params,
|
||||
):
|
||||
print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}")
|
||||
print_verbose(
|
||||
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
|
||||
|
@ -56,7 +65,11 @@ class LiteDebugger:
|
|||
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
|
||||
|
||||
if call_type == "embedding":
|
||||
for message in messages: # assuming the input is a list as required by the embedding function
|
||||
for (
|
||||
message
|
||||
) in (
|
||||
messages
|
||||
): # assuming the input is a list as required by the embedding function
|
||||
litellm_data_obj = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
|
@ -79,7 +92,9 @@ class LiteDebugger:
|
|||
elif call_type == "completion":
|
||||
litellm_data_obj = {
|
||||
"model": model,
|
||||
"messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}],
|
||||
"messages": messages
|
||||
if isinstance(messages, list)
|
||||
else [{"role": "user", "content": messages}],
|
||||
"end_user": end_user,
|
||||
"status": "initiated",
|
||||
"litellm_call_id": litellm_call_id,
|
||||
|
@ -95,20 +110,30 @@ class LiteDebugger:
|
|||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
print_verbose(f"LiteDebugger: completion api response - {response.text}")
|
||||
print_verbose(
|
||||
f"LiteDebugger: completion api response - {response.text}"
|
||||
)
|
||||
except:
|
||||
print_verbose(
|
||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream):
|
||||
print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}")
|
||||
def post_call_log_event(
|
||||
self, original_response, litellm_call_id, print_verbose, call_type, stream
|
||||
):
|
||||
print_verbose(
|
||||
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
if call_type == "embedding":
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
"additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector
|
||||
"additional_details": {
|
||||
"original_response": str(
|
||||
original_response["data"][0]["embedding"][:5]
|
||||
)
|
||||
}, # don't store the entire vector
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
|
@ -122,7 +147,11 @@ class LiteDebugger:
|
|||
elif call_type == "completion" and stream:
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
"additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response},
|
||||
"additional_details": {
|
||||
"original_response": "Streamed response"
|
||||
if isinstance(original_response, types.GeneratorType)
|
||||
else original_response
|
||||
},
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
|
@ -146,10 +175,12 @@ class LiteDebugger:
|
|||
end_time,
|
||||
litellm_call_id,
|
||||
print_verbose,
|
||||
call_type,
|
||||
stream = False
|
||||
call_type,
|
||||
stream=False,
|
||||
):
|
||||
print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}")
|
||||
print_verbose(
|
||||
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
|
||||
|
@ -186,7 +217,7 @@ class LiteDebugger:
|
|||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif call_type == "completion" and stream == True:
|
||||
if len(response_obj["content"]) > 0: # don't log the empty strings
|
||||
if len(response_obj["content"]) > 0: # don't log the empty strings
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"total_cost": total_cost,
|
||||
|
|
|
@ -18,19 +18,17 @@ class PromptLayerLogger:
|
|||
# Method definition
|
||||
try:
|
||||
new_kwargs = {}
|
||||
new_kwargs['model'] = kwargs['model']
|
||||
new_kwargs['messages'] = kwargs['messages']
|
||||
new_kwargs["model"] = kwargs["model"]
|
||||
new_kwargs["messages"] = kwargs["messages"]
|
||||
|
||||
# add kwargs["optional_params"] to new_kwargs
|
||||
for optional_param in kwargs["optional_params"]:
|
||||
new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
|
||||
|
||||
|
||||
print_verbose(
|
||||
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
||||
)
|
||||
|
||||
|
||||
request_response = requests.post(
|
||||
"https://api.promptlayer.com/rest/track-request",
|
||||
json={
|
||||
|
@ -51,8 +49,8 @@ class PromptLayerLogger:
|
|||
f"Prompt Layer Logging: success - final response object: {request_response.text}"
|
||||
)
|
||||
response_json = request_response.json()
|
||||
if "success" not in request_response.json():
|
||||
raise Exception("Promptlayer did not successfully log the response!")
|
||||
if "success" not in request_response.json():
|
||||
raise Exception("Promptlayer did not successfully log the response!")
|
||||
|
||||
if "request_id" in response_json:
|
||||
print(kwargs["litellm_params"]["metadata"])
|
||||
|
@ -62,10 +60,12 @@ class PromptLayerLogger:
|
|||
json={
|
||||
"request_id": response_json["request_id"],
|
||||
"api_key": self.key,
|
||||
"metadata": kwargs["litellm_params"]["metadata"]
|
||||
"metadata": kwargs["litellm_params"]["metadata"],
|
||||
},
|
||||
)
|
||||
print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}")
|
||||
print_verbose(
|
||||
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
|
||||
)
|
||||
|
||||
except:
|
||||
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
|
||||
|
|
|
@ -9,6 +9,7 @@ import traceback
|
|||
import datetime, subprocess, sys
|
||||
import litellm
|
||||
|
||||
|
||||
class Supabase:
|
||||
# Class variables or attributes
|
||||
supabase_table_name = "request_logs"
|
||||
|
|
|
@ -2,6 +2,7 @@ class TraceloopLogger:
|
|||
def __init__(self):
|
||||
from traceloop.sdk.tracing.tracing import TracerWrapper
|
||||
from traceloop.sdk import Traceloop
|
||||
|
||||
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
|
||||
self.tracer_wrapper = TracerWrapper()
|
||||
|
||||
|
@ -29,15 +30,18 @@ class TraceloopLogger:
|
|||
)
|
||||
if "stop" in optional_params:
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop")
|
||||
SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
|
||||
optional_params.get("stop"),
|
||||
)
|
||||
if "frequency_penalty" in optional_params:
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty")
|
||||
SpanAttributes.LLM_FREQUENCY_PENALTY,
|
||||
optional_params.get("frequency_penalty"),
|
||||
)
|
||||
if "presence_penalty" in optional_params:
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_PRESENCE_PENALTY, optional_params.get("presence_penalty")
|
||||
SpanAttributes.LLM_PRESENCE_PENALTY,
|
||||
optional_params.get("presence_penalty"),
|
||||
)
|
||||
if "top_p" in optional_params:
|
||||
span.set_attribute(
|
||||
|
@ -45,7 +49,10 @@ class TraceloopLogger:
|
|||
)
|
||||
if "tools" in optional_params or "functions" in optional_params:
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions"))
|
||||
SpanAttributes.LLM_REQUEST_FUNCTIONS,
|
||||
optional_params.get(
|
||||
"tools", optional_params.get("functions")
|
||||
),
|
||||
)
|
||||
if "user" in optional_params:
|
||||
span.set_attribute(
|
||||
|
@ -53,7 +60,8 @@ class TraceloopLogger:
|
|||
)
|
||||
if "max_tokens" in optional_params:
|
||||
span.set_attribute(
|
||||
SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens")
|
||||
SpanAttributes.LLM_REQUEST_MAX_TOKENS,
|
||||
kwargs.get("max_tokens"),
|
||||
)
|
||||
if "temperature" in optional_params:
|
||||
span.set_attribute(
|
||||
|
|
|
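Similarly, the TraceloopLogger above is wired in through the callback list; a minimal sketch, assuming the traceloop-sdk package is installed:

# Sketch: enable Traceloop tracing via litellm's success callbacks.
import litellm
from litellm import completion

litellm.success_callback = ["traceloop"]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)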
@ -1,4 +1,4 @@
|
|||
imported_openAIResponse=True
|
||||
imported_openAIResponse = True
|
||||
try:
|
||||
import io
|
||||
import logging
|
||||
|
@ -12,15 +12,12 @@ try:
|
|||
else:
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
K = TypeVar("K", bound=str)
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class OpenAIResponse(Protocol[K, V]): # type: ignore
|
||||
class OpenAIResponse(Protocol[K, V]): # type: ignore
|
||||
# contains a (known) object attribute
|
||||
object: Literal["chat.completion", "edit", "text_completion"]
|
||||
|
||||
|
@ -30,7 +27,6 @@ try:
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover

class OpenAIRequestResponseResolver:
def __call__(
self,

@ -44,7 +40,9 @@ try:
elif response["object"] == "text_completion":
return self._resolve_completion(request, response, time_elapsed)
elif response["object"] == "chat.completion":
return self._resolve_chat_completion(request, response, time_elapsed)
return self._resolve_chat_completion(
request, response, time_elapsed
)
else:
logger.info(f"Unknown OpenAI response object: {response['object']}")
except Exception as e:
@ -113,7 +111,8 @@ try:
"""Resolves the request and response objects for `openai.Completion`."""
request_str = f"\n\n**Prompt**: {request['prompt']}\n"
choices = [
f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
f"\n\n**Completion**: {choice['text']}\n"
for choice in response["choices"]
]

return self._request_response_result_to_trace(

@ -167,9 +166,9 @@ try:
]
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace
except:
imported_openAIResponse=False

except:
imported_openAIResponse = False


#### What this does ####
@ -182,29 +181,34 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback


class WeightsBiasesLogger:
# Class variables or attributes
def __init__(self):
try:
import wandb
except:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
if imported_openAIResponse==False:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
if imported_openAIResponse == False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
self.resolver = OpenAIRequestResponseResolver()

def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
import wandb

try:
print_verbose(
f"W&B Logging - Enters logging function for model {kwargs}"
)
print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
run = wandb.init()
print_verbose(response_obj)

trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds())
trace = self.resolver(
kwargs, response_obj, (end_time - start_time).total_seconds()
)

if trace is not None:
run.log({"trace": trace})
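The logger above resolves the raw response into a W&B trace tree and records it with run.log({"trace": trace}). A minimal sketch of enabling it, assuming "wandb" is a valid success_callback name and WANDB_API_KEY is set in the environment:

# Sketch (assumptions noted above): send completion traces to Weights & Biases.
import litellm

litellm.success_callback = ["wandb"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)
# WeightsBiasesLogger.log_event() then builds the trace via
# OpenAIRequestResponseResolver and calls run.log({"trace": trace}).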
@ -7,76 +7,93 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, Choices, Message
|
||||
import litellm
|
||||
|
||||
|
||||
class AI21Error(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.ai21.com/studio/v1/"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class AI21Config():
|
||||
|
||||
class AI21Config:
|
||||
"""
|
||||
Reference: https://docs.ai21.com/reference/j2-complete-ref
|
||||
|
||||
The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
|
||||
|
||||
- `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
|
||||
|
||||
|
||||
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
|
||||
|
||||
|
||||
- `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
|
||||
|
||||
|
||||
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
|
||||
|
||||
|
||||
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
|
||||
|
||||
|
||||
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
|
||||
|
||||
|
||||
- `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
|
||||
|
||||
|
||||
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
|
||||
|
||||
|
||||
- `presencePenalty` (object): Placeholder for presence penalty object.
|
||||
|
||||
|
||||
- `countPenalty` (object): Placeholder for count penalty object.
|
||||
"""
|
||||
numResults: Optional[int]=None
|
||||
maxTokens: Optional[int]=None
|
||||
minTokens: Optional[int]=None
|
||||
temperature: Optional[float]=None
|
||||
topP: Optional[float]=None
|
||||
stopSequences: Optional[list]=None
|
||||
topKReturn: Optional[int]=None
|
||||
frequencePenalty: Optional[dict]=None
|
||||
presencePenalty: Optional[dict]=None
|
||||
countPenalty: Optional[dict]=None
|
||||
|
||||
def __init__(self,
|
||||
numResults: Optional[int]=None,
|
||||
maxTokens: Optional[int]=None,
|
||||
minTokens: Optional[int]=None,
|
||||
temperature: Optional[float]=None,
|
||||
topP: Optional[float]=None,
|
||||
stopSequences: Optional[list]=None,
|
||||
topKReturn: Optional[int]=None,
|
||||
frequencePenalty: Optional[dict]=None,
|
||||
presencePenalty: Optional[dict]=None,
|
||||
countPenalty: Optional[dict]=None) -> None:
|
||||
numResults: Optional[int] = None
|
||||
maxTokens: Optional[int] = None
|
||||
minTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
stopSequences: Optional[list] = None
|
||||
topKReturn: Optional[int] = None
|
||||
frequencePenalty: Optional[dict] = None
|
||||
presencePenalty: Optional[dict] = None
|
||||
countPenalty: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
numResults: Optional[int] = None,
|
||||
maxTokens: Optional[int] = None,
|
||||
minTokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[float] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
topKReturn: Optional[int] = None,
|
||||
frequencePenalty: Optional[dict] = None,
|
||||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
|
@ -91,6 +108,7 @@ def validate_environment(api_key):
}
return headers


def completion(
model: str,
messages: list,
@ -110,20 +128,18 @@ def completion(
|
|||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
|
||||
|
||||
## Load Config
|
||||
config = litellm.AI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
|
@ -134,29 +150,26 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
api_base + model + "/complete", headers=headers, data=json.dumps(data)
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise AI21Error(
|
||||
status_code=response.status_code,
|
||||
message=response.text
|
||||
)
|
||||
raise AI21Error(status_code=response.status_code, message=response.text)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
try:
|
||||
|
@ -164,18 +177,22 @@ def completion(
|
|||
for idx, item in enumerate(completion_response["completions"]):
|
||||
if len(item["data"]["text"]) > 0:
|
||||
message_obj = Message(content=item["data"]["text"])
|
||||
else:
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finishReason"]["reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)
|
||||
raise AI21Error(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content"))
|
||||
)
|
||||
|
@ -189,6 +206,7 @@ def completion(
|
|||
}
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
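The AI21 handler above merges AI21Config.get_config() into optional_params, and arguments passed directly to completion() win over class-level defaults (the completion(top_k=3) > ai21_config(top_k=3) comment). A minimal sketch of that precedence; the model name is illustrative only:

# Sketch: class-level AI21 defaults vs. per-call overrides.
import litellm

litellm.AI21Config(maxTokens=256, temperature=0.3)  # provider-wide defaults

litellm.completion(
    model="j2-mid",                                  # illustrative AI21 model name
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.9,                                 # overrides the 0.3 default above
)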
@ -8,17 +8,21 @@ import litellm
|
|||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import httpx
|
||||
|
||||
|
||||
class AlephAlphaError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.aleph-alpha.com/complete"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class AlephAlphaConfig():
|
||||
|
||||
class AlephAlphaConfig:
|
||||
"""
|
||||
Reference: https://docs.aleph-alpha.com/api/complete/
|
||||
|
||||
|
@ -42,13 +46,13 @@ class AlephAlphaConfig():
|
|||
|
||||
- `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied.
|
||||
|
||||
- `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
|
||||
- `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
|
||||
|
||||
- `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty.
|
||||
|
||||
- `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions.
|
||||
|
||||
- `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side.
|
||||
- `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side.
|
||||
|
||||
- `n` (integer, nullable; default value: 1): The number of completions to return.
|
||||
|
||||
|
@ -68,87 +72,101 @@ class AlephAlphaConfig():
|
|||
|
||||
- `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion.
|
||||
|
||||
- `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
|
||||
- `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
|
||||
|
||||
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
|
||||
"""
|
||||
maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens
|
||||
minimum_tokens: Optional[int]=None
|
||||
echo: Optional[bool]=None
|
||||
temperature: Optional[int]=None
|
||||
top_k: Optional[int]=None
|
||||
top_p: Optional[int]=None
|
||||
presence_penalty: Optional[int]=None
|
||||
frequency_penalty: Optional[int]=None
|
||||
sequence_penalty: Optional[int]=None
|
||||
sequence_penalty_min_length: Optional[int]=None
|
||||
repetition_penalties_include_prompt: Optional[bool]=None
|
||||
repetition_penalties_include_completion: Optional[bool]=None
|
||||
use_multiplicative_presence_penalty: Optional[bool]=None
|
||||
use_multiplicative_frequency_penalty: Optional[bool]=None
|
||||
use_multiplicative_sequence_penalty: Optional[bool]=None
|
||||
penalty_bias: Optional[str]=None
|
||||
penalty_exceptions_include_stop_sequences: Optional[bool]=None
|
||||
best_of: Optional[int]=None
|
||||
n: Optional[int]=None
|
||||
logit_bias: Optional[dict]=None
|
||||
log_probs: Optional[int]=None
|
||||
stop_sequences: Optional[list]=None
|
||||
tokens: Optional[bool]=None
|
||||
raw_completion: Optional[bool]=None
|
||||
disable_optimizations: Optional[bool]=None
|
||||
completion_bias_inclusion: Optional[list]=None
|
||||
completion_bias_exclusion: Optional[list]=None
|
||||
completion_bias_inclusion_first_token_only: Optional[bool]=None
|
||||
completion_bias_exclusion_first_token_only: Optional[bool]=None
|
||||
contextual_control_threshold: Optional[int]=None
|
||||
control_log_additive: Optional[bool]=None
|
||||
|
||||
maximum_tokens: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # aleph alpha requires max tokens
|
||||
minimum_tokens: Optional[int] = None
|
||||
echo: Optional[bool] = None
|
||||
temperature: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
sequence_penalty: Optional[int] = None
|
||||
sequence_penalty_min_length: Optional[int] = None
|
||||
repetition_penalties_include_prompt: Optional[bool] = None
|
||||
repetition_penalties_include_completion: Optional[bool] = None
|
||||
use_multiplicative_presence_penalty: Optional[bool] = None
|
||||
use_multiplicative_frequency_penalty: Optional[bool] = None
|
||||
use_multiplicative_sequence_penalty: Optional[bool] = None
|
||||
penalty_bias: Optional[str] = None
|
||||
penalty_exceptions_include_stop_sequences: Optional[bool] = None
|
||||
best_of: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
log_probs: Optional[int] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
tokens: Optional[bool] = None
|
||||
raw_completion: Optional[bool] = None
|
||||
disable_optimizations: Optional[bool] = None
|
||||
completion_bias_inclusion: Optional[list] = None
|
||||
completion_bias_exclusion: Optional[list] = None
|
||||
completion_bias_inclusion_first_token_only: Optional[bool] = None
|
||||
completion_bias_exclusion_first_token_only: Optional[bool] = None
|
||||
contextual_control_threshold: Optional[int] = None
|
||||
control_log_additive: Optional[bool] = None
|
||||
|
||||
def __init__(self,
|
||||
maximum_tokens: Optional[int]=None,
|
||||
minimum_tokens: Optional[int]=None,
|
||||
echo: Optional[bool]=None,
|
||||
temperature: Optional[int]=None,
|
||||
top_k: Optional[int]=None,
|
||||
top_p: Optional[int]=None,
|
||||
presence_penalty: Optional[int]=None,
|
||||
frequency_penalty: Optional[int]=None,
|
||||
sequence_penalty: Optional[int]=None,
|
||||
sequence_penalty_min_length: Optional[int]=None,
|
||||
repetition_penalties_include_prompt: Optional[bool]=None,
|
||||
repetition_penalties_include_completion: Optional[bool]=None,
|
||||
use_multiplicative_presence_penalty: Optional[bool]=None,
|
||||
use_multiplicative_frequency_penalty: Optional[bool]=None,
|
||||
use_multiplicative_sequence_penalty: Optional[bool]=None,
|
||||
penalty_bias: Optional[str]=None,
|
||||
penalty_exceptions_include_stop_sequences: Optional[bool]=None,
|
||||
best_of: Optional[int]=None,
|
||||
n: Optional[int]=None,
|
||||
logit_bias: Optional[dict]=None,
|
||||
log_probs: Optional[int]=None,
|
||||
stop_sequences: Optional[list]=None,
|
||||
tokens: Optional[bool]=None,
|
||||
raw_completion: Optional[bool]=None,
|
||||
disable_optimizations: Optional[bool]=None,
|
||||
completion_bias_inclusion: Optional[list]=None,
|
||||
completion_bias_exclusion: Optional[list]=None,
|
||||
completion_bias_inclusion_first_token_only: Optional[bool]=None,
|
||||
completion_bias_exclusion_first_token_only: Optional[bool]=None,
|
||||
contextual_control_threshold: Optional[int]=None,
|
||||
control_log_additive: Optional[bool]=None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maximum_tokens: Optional[int] = None,
|
||||
minimum_tokens: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
sequence_penalty: Optional[int] = None,
|
||||
sequence_penalty_min_length: Optional[int] = None,
|
||||
repetition_penalties_include_prompt: Optional[bool] = None,
|
||||
repetition_penalties_include_completion: Optional[bool] = None,
|
||||
use_multiplicative_presence_penalty: Optional[bool] = None,
|
||||
use_multiplicative_frequency_penalty: Optional[bool] = None,
|
||||
use_multiplicative_sequence_penalty: Optional[bool] = None,
|
||||
penalty_bias: Optional[str] = None,
|
||||
penalty_exceptions_include_stop_sequences: Optional[bool] = None,
|
||||
best_of: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
log_probs: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
tokens: Optional[bool] = None,
|
||||
raw_completion: Optional[bool] = None,
|
||||
disable_optimizations: Optional[bool] = None,
|
||||
completion_bias_inclusion: Optional[list] = None,
|
||||
completion_bias_exclusion: Optional[list] = None,
|
||||
completion_bias_inclusion_first_token_only: Optional[bool] = None,
|
||||
completion_bias_exclusion_first_token_only: Optional[bool] = None,
|
||||
contextual_control_threshold: Optional[int] = None,
|
||||
control_log_additive: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
|
@ -160,6 +178,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers


def completion(
model: str,
messages: list,
@ -177,9 +196,11 @@ def completion(
|
|||
headers = validate_environment(api_key)
|
||||
|
||||
## Load Config
|
||||
config = litellm.AlephAlphaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AlephAlphaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
completion_url = api_base
|
||||
|
@ -188,21 +209,17 @@ def completion(
|
|||
if "control" in model: # follow the ###Instruction / ###Response format
|
||||
for idx, message in enumerate(messages):
|
||||
if "role" in message:
|
||||
if idx == 0: # set first message as instruction (required), let later user messages be input
|
||||
if (
|
||||
idx == 0
|
||||
): # set first message as instruction (required), let later user messages be input
|
||||
prompt += f"###Instruction: {message['content']}"
|
||||
else:
|
||||
if message["role"] == "system":
|
||||
prompt += (
|
||||
f"###Instruction: {message['content']}"
|
||||
)
|
||||
prompt += f"###Instruction: {message['content']}"
|
||||
elif message["role"] == "user":
|
||||
prompt += (
|
||||
f"###Input: {message['content']}"
|
||||
)
|
||||
prompt += f"###Input: {message['content']}"
|
||||
else:
|
||||
prompt += (
|
||||
f"###Response: {message['content']}"
|
||||
)
|
||||
prompt += f"###Response: {message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
|
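For Aleph Alpha "control" models, the branch above flattens chat messages into the ###Instruction / ###Input / ###Response format. A standalone sketch of the same mapping:

# Sketch of the prompt flattening performed above (first message -> Instruction,
# later user messages -> Input, anything else -> Response).
messages = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is the capital of France?"},
]

prompt = ""
for idx, message in enumerate(messages):
    if idx == 0:
        prompt += f"###Instruction: {message['content']}"
    elif message["role"] == "user":
        prompt += f"###Input: {message['content']}"
    else:
        prompt += f"###Response: {message['content']}"

# prompt == "###Instruction: Answer briefly.###Input: What is the capital of France?"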
@ -215,24 +232,27 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
|
@ -247,18 +267,23 @@ def completion(
|
|||
for idx, item in enumerate(completion_response["completions"]):
|
||||
if len(item["completion"]) > 0:
|
||||
message_obj = Message(content=item["completion"])
|
||||
else:
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except:
|
||||
raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code)
|
||||
raise AlephAlphaError(
|
||||
message=json.dumps(completion_response),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
@ -268,11 +293,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
@ -5,56 +5,76 @@ import requests
|
|||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
|
||||
class AnthropicError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.anthropic.com/v1/complete"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class AnthropicConfig():
|
||||
|
||||
class AnthropicConfig:
|
||||
"""
|
||||
Reference: https://docs.anthropic.com/claude/reference/complete_post
|
||||
|
||||
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
|
||||
"""
|
||||
max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
|
||||
stop_sequences: Optional[list]=None
|
||||
temperature: Optional[int]=None
|
||||
top_p: Optional[int]=None
|
||||
top_k: Optional[int]=None
|
||||
metadata: Optional[dict]=None
|
||||
|
||||
def __init__(self,
|
||||
max_tokens_to_sample: Optional[int]=256, # anthropic requires a default
|
||||
stop_sequences: Optional[list]=None,
|
||||
temperature: Optional[int]=None,
|
||||
top_p: Optional[int]=None,
|
||||
top_k: Optional[int]=None,
|
||||
metadata: Optional[dict]=None) -> None:
|
||||
|
||||
max_tokens_to_sample: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # anthropic requires a default
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
# makes headers for API call
|
||||
|
@ -71,6 +91,7 @@ def validate_environment(api_key):
|
|||
}
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -87,21 +108,25 @@ def completion(
|
|||
):
|
||||
headers = validate_environment(api_key)
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages
|
||||
)
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
|
||||
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
|
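The Anthropic path above prefers a template registered in custom_prompt_dict over the generic prompt_factory. A sketch of registering one, assuming litellm.register_prompt_template is the entry point that populates custom_prompt_dict; the role strings are placeholders:

# Sketch (assumption noted above): register a custom prompt template for a model.
import litellm

litellm.register_prompt_template(
    model="claude-instant-1",
    roles={
        "system": {"pre_message": "\n\nHuman: <system>", "post_message": "</system>"},
        "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
        "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
    },
    initial_prompt_value="",
    final_prompt_value="\n\nAssistant: ",
)
# completion(model="claude-instant-1", ...) then takes the custom_prompt() branch above.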
@ -116,7 +141,7 @@ def completion(
|
|||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = requests.post(
|
||||
|
@ -125,18 +150,20 @@ def completion(
|
|||
data=json.dumps(data),
|
||||
stream=optional_params["stream"],
|
||||
)
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(status_code=response.status_code, message=response.text)
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
return response.iter_lines()
|
||||
else:
|
||||
response = requests.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
response = requests.post(api_base, headers=headers, data=json.dumps(data))
|
||||
if response.status_code != 200:
|
||||
raise AnthropicError(status_code=response.status_code, message=response.text)
|
||||
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
|
@ -159,9 +186,9 @@ def completion(
|
|||
)
|
||||
else:
|
||||
if len(completion_response["completion"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response[
|
||||
"completion"
|
||||
]
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["completion"]
|
||||
model_response.choices[0].finish_reason = completion_response["stop_reason"]
|
||||
|
||||
## CALCULATING USAGE
|
||||
|
@ -177,11 +204,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
@ -1,7 +1,13 @@
from typing import Optional, Union, Any
import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
)
from typing import Callable, Optional
from litellm import OpenAIConfig
import litellm, json
@ -9,8 +15,15 @@ import httpx
|
|||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||
|
||||
|
||||
class AzureOpenAIError(Exception):
|
||||
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
|
||||
def __init__(
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
if request:
|
||||
|
@ -20,11 +33,14 @@ class AzureOpenAIError(Exception):
|
|||
if response:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
self.response = httpx.Response(
|
||||
status_code=status_code, request=self.request
|
||||
)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AzureOpenAIConfig(OpenAIConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
@ -49,33 +65,37 @@ class AzureOpenAIConfig(OpenAIConfig):
|
|||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]]= None,
|
||||
functions: Optional[list]= None,
|
||||
logit_bias: Optional[dict]= None,
|
||||
max_tokens: Optional[int]= None,
|
||||
n: Optional[int]= None,
|
||||
presence_penalty: Optional[int]= None,
|
||||
stop: Optional[Union[str,list]]=None,
|
||||
temperature: Optional[int]= None,
|
||||
top_p: Optional[int]= None) -> None:
|
||||
super().__init__(frequency_penalty,
|
||||
function_call,
|
||||
functions,
|
||||
logit_bias,
|
||||
max_tokens,
|
||||
n,
|
||||
presence_penalty,
|
||||
stop,
|
||||
temperature,
|
||||
top_p)
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
frequency_penalty,
|
||||
function_call,
|
||||
functions,
|
||||
logit_bias,
|
||||
max_tokens,
|
||||
n,
|
||||
presence_penalty,
|
||||
stop,
|
||||
temperature,
|
||||
top_p,
|
||||
)
|
||||
|
||||
|
||||
class AzureChatCompletion(BaseLLM):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -89,49 +109,51 @@ class AzureChatCompletion(BaseLLM):
|
|||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
return headers
|
||||
|
||||
def completion(self,
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
print_verbose: Callable,
|
||||
timeout,
|
||||
logging_obj,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
logger_fn,
|
||||
acompletion: bool = False,
|
||||
headers: Optional[dict]=None,
|
||||
client = None,
|
||||
):
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
print_verbose: Callable,
|
||||
timeout,
|
||||
logging_obj,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
logger_fn,
|
||||
acompletion: bool = False,
|
||||
headers: Optional[dict] = None,
|
||||
client=None,
|
||||
):
|
||||
super().completion()
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
|
||||
if model is None or messages is None:
|
||||
raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message=f"Missing model or messages"
|
||||
)
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
|
||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||
### if so - set the model as part of the base url
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
### if so - set the model as part of the base url
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
if client is None:
|
||||
if not api_base.endswith("/"):
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
|
||||
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"base_url": f"{api_base}",
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
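When api_base points at a Cloudflare AI Gateway, the branch above appends the deployment name to the URL and passes base_url instead of azure_endpoint. A sketch of the resulting client parameters; the gateway path, version string, and key are placeholders:

# Sketch of the Cloudflare AI Gateway branch (placeholder identifiers).
api_base = "https://gateway.ai.cloudflare.com/v1/<account>/<gateway>/azure-openai/<resource>"
model = "my-deployment"

if not api_base.endswith("/"):
    api_base += "/"
api_base += f"{model}"

azure_client_params = {
    "api_version": "2023-07-01-preview",  # placeholder
    "base_url": f"{api_base}",
    "max_retries": 2,
    "timeout": 600,
    "api_key": "<AZURE_API_KEY>",         # placeholder
}
# AzureOpenAI(**azure_client_params) or AsyncAzureOpenAI(**azure_client_params)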
@ -142,26 +164,53 @@ class AzureChatCompletion(BaseLLM):
|
|||
client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params)
|
||||
|
||||
|
||||
data = {"model": None, "messages": messages, **optional_params}
|
||||
else:
|
||||
data = {
|
||||
"model": None,
|
||||
"messages": messages,
|
||||
**optional_params
|
||||
"model": model, # type: ignore
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"model": model, # type: ignore
|
||||
"messages": messages,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
if acompletion is True:
|
||||
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||
return self.async_streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj)
|
||||
return self.acompletion(
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
model=model,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
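The block above routes to async_streaming, acompletion, or streaming depending on the acompletion flag and the stream option. A caller-side sketch; the deployment name is a placeholder:

# Sketch: acompletion=True + stream=True reaches AzureChatCompletion.async_streaming().
import asyncio
import litellm

async def main():
    stream = await litellm.acompletion(
        model="azure/my-deployment",   # placeholder deployment
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)

asyncio.run(main())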
@ -169,16 +218,18 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_key=api_key,
|
||||
additional_args={
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
},
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
@ -196,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -209,30 +260,36 @@ class AzureChatCompletion(BaseLLM):
|
|||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
return convert_to_model_response_object(
|
||||
response_object=json.loads(stringified_response),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def acompletion(self,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
model: str,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None, # this is the AsyncAzureOpenAI
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
||||
async def acompletion(
|
||||
self,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
model: str,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
@ -252,35 +309,46 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(
|
||||
response_object=json.loads(response.model_dump_json()),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
if hasattr(e, "status_code"):
|
||||
raise e
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||
|
||||
def streaming(self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client=None,
|
||||
):
|
||||
def streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
):
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
@ -300,25 +368,36 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
response = azure_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="azure",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def async_streaming(self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None,
|
||||
):
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
):
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -326,39 +405,49 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": data.pop("max_retries", 2),
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="azure",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
self,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
azure_client_params: dict,
|
||||
api_key: str,
|
||||
input: list,
|
||||
api_key: str,
|
||||
input: list,
|
||||
client=None,
|
||||
logging_obj=None
|
||||
):
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
try:
|
||||
if client is None:
|
||||
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
|
@ -367,50 +456,53 @@ class AzureChatCompletion(BaseLLM):
|
|||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(
|
||||
response_object=json.loads(stringified_response),
|
||||
model_response_object=model_response,
|
||||
response_type="embedding",
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
model: str,
|
||||
input: list,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
timeout: float,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None,
|
||||
aembedding=None,
|
||||
):
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: list,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
timeout: float,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
):
|
||||
super().embedding()
|
||||
exception_mapping_worked = False
|
||||
if self._client_session is None:
|
||||
self._client_session = self.create_client_session()
|
||||
try:
|
||||
data = {
|
||||
"model": model,
|
||||
"input": input,
|
||||
**optional_params
|
||||
}
|
||||
try:
|
||||
data = {"model": model, "input": input, **optional_params}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
"azure_deployment": model,
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
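The embedding path builds the same azure_client_params and then calls azure_client.embeddings.create(**data), or aembedding() when the async flag is set. A caller-side sketch with placeholder endpoint, version, and key:

# Sketch: calling the Azure embedding path shown above (placeholders throughout).
import litellm

litellm.embedding(
    model="azure/text-embedding-ada-002",          # "azure/<deployment>" routing assumed
    input=["good morning from litellm"],
    api_key="<AZURE_API_KEY>",
    api_base="https://<resource>.openai.azure.com/",
    api_version="2023-07-01-preview",
)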
@ -427,119 +519,130 @@ class AzureChatCompletion(BaseLLM):
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
|
||||
response = self.aembedding(
|
||||
data=data,
|
||||
input=input,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
model_response=model_response,
|
||||
azure_client_params=azure_client_params,
|
||||
)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
## COMPLETION CALL
|
||||
response = azure_client.embeddings.create(**data) # type: ignore
|
||||
## COMPLETION CALL
|
||||
response = azure_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
original_response=response,
|
||||
)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
else:
|
||||
import traceback
|
||||
|
||||
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
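For context, a minimal usage sketch of how this embedding path is normally reached from user code; the deployment name, key, endpoint, and API version below are placeholders, not values from this diff.

import litellm

# Illustrative only: "azure/<deployment>" routes to AzureChatCompletion.embedding()
response = litellm.embedding(
    model="azure/my-embedding-deployment",  # hypothetical deployment name
    input=["good morning from litellm"],
    api_key="my-azure-api-key",  # placeholder
    api_base="https://my-endpoint.openai.azure.com",  # placeholder
    api_version="2023-07-01-preview",  # example API version
)
print(response)  # OpenAI-style embedding response object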
|
||||
|
||||
async def aimage_generation(
|
||||
self,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
self,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
azure_client_params: dict,
|
||||
api_key: str,
|
||||
input: list,
|
||||
api_key: str,
|
||||
input: list,
|
||||
client=None,
|
||||
logging_obj=None
|
||||
):
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
try:
|
||||
if client is None:
|
||||
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
|
||||
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
|
||||
client_session = litellm.aclient_session or httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(),
|
||||
)
|
||||
openai_aclient = AsyncAzureOpenAI(
|
||||
http_client=client_session, **azure_client_params
|
||||
)
|
||||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.images.generate(**data)
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(
|
||||
response_object=json.loads(stringified_response),
|
||||
model_response_object=model_response,
|
||||
response_type="image_generation",
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def image_generation(self,
|
||||
prompt: str,
|
||||
timeout: float,
|
||||
model: Optional[str]=None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
model_response: Optional[litellm.utils.ImageResponse] = None,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aimg_generation=None,
|
||||
):
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
timeout: float,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
model_response: Optional[litellm.utils.ImageResponse] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aimg_generation=None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
try:
|
||||
if model and len(model) > 0:
|
||||
model = model
|
||||
else:
|
||||
model = None
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
data = {"model": model, "prompt": prompt, **optional_params}
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
|
||||
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
"azure_endpoint": api_base,
|
||||
"azure_deployment": model,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
|
@ -547,39 +650,47 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if aimg_generation == True:
|
||||
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
|
||||
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
|
||||
return response
|
||||
|
||||
|
||||
if client is None:
|
||||
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
|
||||
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
|
||||
client_session = litellm.client_session or httpx.Client(
|
||||
transport=CustomHTTPTransport(),
|
||||
)
|
||||
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": False,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
## COMPLETION CALL
|
||||
response = azure_client.images.generate(**data) # type: ignore
|
||||
response = azure_client.images.generate(**data) # type: ignore
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
# return response
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
else:
|
||||
import traceback
|
||||
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
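A rough sketch, under the same assumptions, of how the image generation handler above is typically invoked; deployment and credential values are placeholders.

import litellm

# Illustrative only: "azure/<dall-e deployment>" routes to AzureChatCompletion.image_generation()
image_response = litellm.image_generation(
    prompt="a sunset over the ocean",
    model="azure/dall-e-3",  # hypothetical deployment name
    api_base="https://my-endpoint.openai.azure.com",  # placeholder
    api_key="my-azure-api-key",  # placeholder
    api_version="2023-12-01-preview",  # example API version
)
print(image_response)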
|
||||
|
|
|
@ -1,47 +1,45 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import litellm
import httpx, certifi, ssl
from typing import Optional


class BaseLLM:
_client_session: Optional[httpx.Client] = None

def create_client_session(self):
if litellm.client_session:
if litellm.client_session:
_client_session = litellm.client_session
else:
else:
_client_session = httpx.Client()

return _client_session

def create_aclient_session(self):
if litellm.aclient_session:
if litellm.aclient_session:
_aclient_session = litellm.aclient_session
else:
else:
_aclient_session = httpx.AsyncClient()

return _aclient_session

def __exit__(self):
if hasattr(self, '_client_session'):
if hasattr(self, "_client_session"):
self._client_session.close()

async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'):
if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose()

def validate_environment(self): # set up the environment required to run the model
pass

def completion(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model completion calls
pass

def embedding(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model embedding calls
pass
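As a sketch of how this template is meant to be used, a hypothetical provider would subclass BaseLLM and override the hooks; the class and comments below are made up for illustration and are not part of this commit.

# Hypothetical example, not part of this commit
class MyProviderLLM(BaseLLM):
    def validate_environment(self):
        # e.g. check that the provider's API key env var is set before calling it
        pass

    def completion(self, *args, **kwargs):
        # parse inputs -> call the provider API -> map the result onto ModelResponse
        ...

    def embedding(self, *args, **kwargs):
        # same pattern for embedding endpoints
        ...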
@ -6,6 +6,7 @@ import time
|
|||
from typing import Callable
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class BasetenError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -14,6 +15,7 @@ class BasetenError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
@ -23,6 +25,7 @@ def validate_environment(api_key):
|
|||
headers["Authorization"] = f"Api-Key {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -52,32 +55,38 @@ def completion(
|
|||
"inputs": prompt,
|
||||
"prompt": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url_fragment_1 + model + completion_url_fragment_2,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True if "stream" in optional_params and optional_params["stream"] == True else False
|
||||
stream=True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
)
|
||||
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
|
||||
if "text/event-stream" in response.headers["Content-Type"] or (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
):
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
|
@ -91,9 +100,7 @@ def completion(
|
|||
if (
|
||||
isinstance(completion_response["model_output"], dict)
|
||||
and "data" in completion_response["model_output"]
|
||||
and isinstance(
|
||||
completion_response["model_output"]["data"], list
|
||||
)
|
||||
and isinstance(completion_response["model_output"]["data"], list)
|
||||
):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
|
@ -112,12 +119,19 @@ def completion(
|
|||
if "generated_text" not in completion_response:
|
||||
raise BasetenError(
|
||||
message=f"Unable to parse response. Original response: {response.text}",
|
||||
status_code=response.status_code
|
||||
status_code=response.status_code,
|
||||
)
|
||||
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
|
||||
## GETTING LOGPROBS
|
||||
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
|
||||
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response[0]["generated_text"]
|
||||
## GETTING LOGPROBS
|
||||
if (
|
||||
"details" in completion_response[0]
|
||||
and "tokens" in completion_response[0]["details"]
|
||||
):
|
||||
model_response.choices[0].finish_reason = completion_response[0][
|
||||
"details"
|
||||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response[0]["details"]["tokens"]:
|
||||
sum_logprob += token["logprob"]
|
||||
|
@ -125,7 +139,7 @@ def completion(
|
|||
else:
|
||||
raise BasetenError(
|
||||
message=f"Unable to parse response. Original response: {response.text}",
|
||||
status_code=response.status_code
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
|
@ -139,11 +153,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
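For reference, a hedged sketch of how the Baseten handler above is usually invoked; the model version id is a placeholder.

import os
import litellm

os.environ["BASETEN_API_KEY"] = ""  # placeholder; sent as "Api-Key <key>" by validate_environment()

# Illustrative only: "baseten/<model-version-id>" routes to this handler
response = litellm.completion(
    model="baseten/abc12345",  # hypothetical Baseten model version id
    messages=[{"role": "user", "content": "Hello from litellm"}],
)
print(response["choices"][0]["message"]["content"])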
|
||||
|
|
|
@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
|
|||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
import httpx
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class AmazonTitanConfig():
|
||||
|
||||
class AmazonTitanConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
||||
|
||||
|
@ -29,29 +33,44 @@ class AmazonTitanConfig():
|
|||
- `temperature` (float) temperature for model,
|
||||
- `topP` (int) top p for model
|
||||
"""
|
||||
maxTokenCount: Optional[int]=None
|
||||
stopSequences: Optional[list]=None
|
||||
temperature: Optional[float]=None
|
||||
topP: Optional[int]=None
|
||||
|
||||
def __init__(self,
|
||||
maxTokenCount: Optional[int]=None,
|
||||
stopSequences: Optional[list]=None,
|
||||
temperature: Optional[float]=None,
|
||||
topP: Optional[int]=None) -> None:
|
||||
maxTokenCount: Optional[int] = None
|
||||
stopSequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
class AmazonAnthropicConfig():
|
||||
|
||||
class AmazonAnthropicConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
|
||||
|
@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
|
|||
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
|
||||
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
"""
|
||||
max_tokens_to_sample: Optional[int]=litellm.max_tokens
|
||||
stop_sequences: Optional[list]=None
|
||||
temperature: Optional[float]=None
|
||||
top_k: Optional[int]=None
|
||||
top_p: Optional[int]=None
|
||||
anthropic_version: Optional[str]=None
|
||||
|
||||
def __init__(self,
|
||||
max_tokens_to_sample: Optional[int]=None,
|
||||
stop_sequences: Optional[list]=None,
|
||||
temperature: Optional[float]=None,
|
||||
top_k: Optional[int]=None,
|
||||
top_p: Optional[int]=None,
|
||||
anthropic_version: Optional[str]=None) -> None:
|
||||
max_tokens_to_sample: Optional[int] = litellm.max_tokens
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
anthropic_version: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
class AmazonCohereConfig():
|
||||
|
||||
class AmazonCohereConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
|
@ -100,79 +134,110 @@ class AmazonCohereConfig():
|
|||
- `temperature` (float) model temperature,
|
||||
- `return_likelihood` (string) n/a
|
||||
"""
|
||||
max_tokens: Optional[int]=None
|
||||
temperature: Optional[float]=None
|
||||
return_likelihood: Optional[str]=None
|
||||
|
||||
def __init__(self,
|
||||
max_tokens: Optional[int]=None,
|
||||
temperature: Optional[float]=None,
|
||||
return_likelihood: Optional[str]=None) -> None:
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
class AmazonAI21Config():
|
||||
|
||||
class AmazonAI21Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
||||
|
||||
Supported Params for the Amazon / AI21 models:
|
||||
|
||||
|
||||
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
|
||||
|
||||
|
||||
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
|
||||
|
||||
|
||||
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
|
||||
|
||||
|
||||
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
|
||||
|
||||
|
||||
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
|
||||
|
||||
|
||||
- `presencePenalty` (object): Placeholder for presence penalty object.
|
||||
|
||||
|
||||
- `countPenalty` (object): Placeholder for count penalty object.
|
||||
"""
|
||||
maxTokens: Optional[int]=None
|
||||
temperature: Optional[float]=None
|
||||
topP: Optional[float]=None
|
||||
stopSequences: Optional[list]=None
|
||||
frequencePenalty: Optional[dict]=None
|
||||
presencePenalty: Optional[dict]=None
|
||||
countPenalty: Optional[dict]=None
|
||||
|
||||
def __init__(self,
|
||||
maxTokens: Optional[int]=None,
|
||||
temperature: Optional[float]=None,
|
||||
topP: Optional[float]=None,
|
||||
stopSequences: Optional[list]=None,
|
||||
frequencePenalty: Optional[dict]=None,
|
||||
presencePenalty: Optional[dict]=None,
|
||||
countPenalty: Optional[dict]=None) -> None:
|
||||
maxTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
stopSequences: Optional[list] = None
|
||||
frequencePenalty: Optional[dict] = None
|
||||
presencePenalty: Optional[dict] = None
|
||||
countPenalty: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[float] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
frequencePenalty: Optional[dict] = None,
|
||||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
class AmazonLlamaConfig():
|
||||
|
||||
class AmazonLlamaConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
||||
|
||||
|
@ -182,48 +247,72 @@ class AmazonLlamaConfig():
|
|||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
"""
|
||||
max_gen_len: Optional[int]=None
|
||||
temperature: Optional[float]=None
|
||||
topP: Optional[float]=None
|
||||
|
||||
def __init__(self,
|
||||
maxTokenCount: Optional[int]=None,
|
||||
temperature: Optional[float]=None,
|
||||
topP: Optional[int]=None) -> None:
|
||||
max_gen_len: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name = None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] =None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str]=None,
|
||||
):
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith('os.environ/'):
|
||||
if param and param.startswith("os.environ/"):
|
||||
params_to_check[i] = get_secret(param)
|
||||
# Assign updated values back to parameters
|
||||
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
) = params_to_check
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
|
@ -233,7 +322,10 @@ def init_bedrock_client(
|
|||
elif standard_aws_region_name:
|
||||
region_name = standard_aws_region_name
|
||||
else:
|
||||
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
|
||||
raise BedrockError(
|
||||
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
|
@ -242,9 +334,10 @@ def init_bedrock_client(
|
|||
elif env_aws_bedrock_runtime_endpoint:
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com'
|
||||
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
|
||||
|
||||
import boto3
|
||||
|
||||
if aws_access_key_id != None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
|
@ -257,7 +350,7 @@ def init_bedrock_client(
|
|||
endpoint_url=endpoint_url,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
client = boto3.client(
|
||||
|
@ -276,25 +369,23 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
|||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
else:
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
return prompt
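To make the fallback branch above concrete, a small illustration of what the naive concatenation produces when no custom template is registered and the provider is not anthropic.

# Illustrative only
messages = [
    {"role": "user", "content": "What is Amazon Bedrock?"},
    {"role": "assistant", "content": "A managed service for foundation models."},
]
prompt = ""
for message in messages:
    # every branch appends the raw content, so the result is a plain concatenation
    prompt += f"{message['content']}"
# prompt == "What is Amazon Bedrock?A managed service for foundation models."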
|
||||
|
@ -309,17 +400,18 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
|
|||
|
||||
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
|
@ -327,7 +419,9 @@ def completion(
|
|||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = optional_params.pop(
|
||||
|
@ -343,67 +437,71 @@ def completion(
|
|||
|
||||
model = model
|
||||
provider = model.split(".")[0]
|
||||
prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
|
||||
prompt = convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
stream = inference_params.pop("stream", False)
|
||||
if provider == "anthropic":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AmazonAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = json.dumps({
|
||||
"prompt": prompt,
|
||||
**inference_params
|
||||
})
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "ai21":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({
|
||||
"prompt": prompt,
|
||||
**inference_params
|
||||
})
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "cohere":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AmazonCohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
if optional_params.get("stream", False) == True:
|
||||
inference_params["stream"] = True # cohere requires stream = True in inference params
|
||||
data = json.dumps({
|
||||
"prompt": prompt,
|
||||
**inference_params
|
||||
})
|
||||
inference_params[
|
||||
"stream"
|
||||
] = True # cohere requires stream = True in inference params
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "meta":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = json.dumps({
|
||||
"prompt": prompt,
|
||||
**inference_params
|
||||
})
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "amazon": # amazon titan
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonTitanConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.AmazonTitanConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
})
|
||||
|
||||
data = json.dumps(
|
||||
{
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
accept = 'application/json'
|
||||
contentType = 'application/json'
|
||||
accept = "application/json"
|
||||
contentType = "application/json"
|
||||
if stream == True:
|
||||
if provider == "ai21":
|
||||
## LOGGING
|
||||
|
@ -418,17 +516,17 @@ def completion(
|
|||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = client.invoke_model(
|
||||
body=data,
|
||||
modelId=model,
|
||||
accept=accept,
|
||||
contentType=contentType
|
||||
body=data, modelId=model, accept=accept, contentType=contentType
|
||||
)
|
||||
|
||||
response = response.get('body').read()
|
||||
response = response.get("body").read()
|
||||
return response
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -441,20 +539,20 @@ def completion(
|
|||
)
|
||||
"""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
response = client.invoke_model_with_response_stream(
|
||||
body=data,
|
||||
modelId=model,
|
||||
accept=accept,
|
||||
contentType=contentType
|
||||
body=data, modelId=model, accept=accept, contentType=contentType
|
||||
)
|
||||
response = response.get('body')
|
||||
response = response.get("body")
|
||||
return response
|
||||
try:
|
||||
try:
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
|
@ -465,20 +563,20 @@ def completion(
|
|||
)
|
||||
"""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
)
|
||||
response = client.invoke_model(
|
||||
body=data,
|
||||
modelId=model,
|
||||
accept=accept,
|
||||
contentType=contentType
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
response = client.invoke_model(
|
||||
body=data, modelId=model, accept=accept, contentType=contentType
|
||||
)
|
||||
except Exception as e:
|
||||
raise BedrockError(status_code=500, message=str(e))
|
||||
|
||||
response_body = json.loads(response.get('body').read())
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -491,16 +589,16 @@ def completion(
|
|||
## RESPONSE OBJECT
|
||||
outputText = "default"
|
||||
if provider == "ai21":
|
||||
outputText = response_body.get('completions')[0].get('data').get('text')
|
||||
outputText = response_body.get("completions")[0].get("data").get("text")
|
||||
elif provider == "anthropic":
|
||||
outputText = response_body['completion']
|
||||
outputText = response_body["completion"]
|
||||
model_response["finish_reason"] = response_body["stop_reason"]
|
||||
elif provider == "cohere":
|
||||
elif provider == "cohere":
|
||||
outputText = response_body["generations"][0]["text"]
|
||||
elif provider == "meta":
|
||||
elif provider == "meta":
|
||||
outputText = response_body["generation"]
|
||||
else: # amazon titan
|
||||
outputText = response_body.get('results')[0].get('outputText')
|
||||
outputText = response_body.get("results")[0].get("outputText")
|
||||
|
||||
response_metadata = response.get("ResponseMetadata", {})
|
||||
if response_metadata.get("HTTPStatusCode", 500) >= 400:
|
||||
|
@ -513,12 +611,13 @@ def completion(
|
|||
if len(outputText) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = outputText
|
||||
except:
|
||||
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
|
||||
raise BedrockError(
|
||||
message=json.dumps(outputText),
|
||||
status_code=response_metadata.get("HTTPStatusCode", 500),
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -528,41 +627,47 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
except BedrockError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
else:
|
||||
import traceback
|
||||
|
||||
raise BedrockError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
|
||||
def _embedding_func_single(
|
||||
model: str,
|
||||
input: str,
|
||||
client: Any,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
logging_obj=None,
|
||||
model: str,
|
||||
input: str,
|
||||
client: Any,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
logging_obj=None,
|
||||
):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
## FORMAT EMBEDDING INPUT ##
|
||||
## FORMAT EMBEDDING INPUT ##
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop("user", None) # make sure user is not passed in for bedrock call
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
if provider == "amazon":
|
||||
input = input.replace(os.linesep, " ")
|
||||
data = {"inputText": input, **inference_params}
|
||||
# data = json.dumps(data)
|
||||
elif provider == "cohere":
|
||||
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
data = {"texts": [input], **inference_params} # type: ignore
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
inference_params["input_type"] = inference_params.get(
|
||||
"input_type", "search_document"
|
||||
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
data = {"texts": [input], **inference_params} # type: ignore
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
|
@ -570,12 +675,14 @@ def _embedding_func_single(
|
|||
modelId={model},
|
||||
accept="*/*",
|
||||
contentType="application/json",
|
||||
)""" # type: ignore
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}, "request_str": request_str},
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={
|
||||
"complete_input_dict": {"model": model, "texts": input},
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
|
@ -587,11 +694,11 @@ def _embedding_func_single(
|
|||
response_body = json.loads(response.get("body").read())
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
if provider == "cohere":
|
||||
response = response_body.get("embeddings")
|
||||
# flatten list
|
||||
|
@ -600,7 +707,10 @@ def _embedding_func_single(
|
|||
elif provider == "amazon":
|
||||
return response_body.get("embedding")
|
||||
except Exception as e:
|
||||
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
||||
raise BedrockError(
|
||||
message=f"Embedding Error with model {model}: {e}", status_code=500
|
||||
)
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
|
@ -616,7 +726,9 @@ def embedding(
|
|||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
|
@ -624,11 +736,19 @@ def embedding(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
)
|
||||
)
|
||||
|
||||
## Embedding Call
|
||||
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
|
||||
|
||||
embeddings = [
|
||||
_embedding_func_single(
|
||||
model,
|
||||
i,
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
for i in input
|
||||
] # [TODO]: make these parallel calls
|
||||
|
||||
## Populate OpenAI compliant dictionary
|
||||
embedding_response = []
|
||||
|
@ -647,13 +767,11 @@ def embedding(
|
|||
|
||||
input_str = "".join(input)
|
||||
|
||||
input_tokens+=len(encoding.encode(input_str))
|
||||
input_tokens += len(encoding.encode(input_str))
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens + 0
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
|
||||
return model_response
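The "[TODO]: make these parallel calls" note above refers to the per-item _embedding_func_single calls being sequential; one possible approach, sketched here under the assumption that the blocking boto3 invoke_model call is safe to fan out across threads, is a thread pool (the helper name and worker count are hypothetical).

# Hypothetical sketch for the "[TODO]: make these parallel calls" note above
from concurrent.futures import ThreadPoolExecutor

def _embed_all(model, input, optional_params, client, logging_obj, max_workers=8):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(
                _embedding_func_single,
                model,
                i,
                optional_params=optional_params,
                client=client,
                logging_obj=logging_obj,
            )
            for i in input
        ]
        return [f.result() for f in futures]  # results stay in input order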
|
||||
|
|
|
@ -8,88 +8,106 @@ from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx


class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate")
self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
|
||||
class CohereConfig():
|
||||
|
||||
class CohereConfig:
|
||||
"""
|
||||
Reference: https://docs.cohere.com/reference/generate
|
||||
|
||||
The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters:
|
||||
|
||||
|
||||
- `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
|
||||
|
||||
|
||||
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
|
||||
|
||||
|
||||
- `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
|
||||
|
||||
|
||||
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
|
||||
|
||||
|
||||
- `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
|
||||
|
||||
|
||||
- `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
|
||||
|
||||
|
||||
- `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
|
||||
|
||||
|
||||
- `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
|
||||
|
||||
|
||||
- `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
|
||||
|
||||
|
||||
- `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
|
||||
|
||||
|
||||
- `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
|
||||
|
||||
|
||||
- `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
|
||||
|
||||
|
||||
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
|
||||
"""
|
||||
num_generations: Optional[int]=None
|
||||
max_tokens: Optional[int]=None
|
||||
truncate: Optional[str]=None
|
||||
temperature: Optional[int]=None
|
||||
preset: Optional[str]=None
|
||||
end_sequences: Optional[list]=None
|
||||
stop_sequences: Optional[list]=None
|
||||
k: Optional[int]=None
|
||||
p: Optional[int]=None
|
||||
frequency_penalty: Optional[int]=None
|
||||
presence_penalty: Optional[int]=None
|
||||
return_likelihoods: Optional[str]=None
|
||||
logit_bias: Optional[dict]=None
|
||||
|
||||
def __init__(self,
|
||||
num_generations: Optional[int]=None,
|
||||
max_tokens: Optional[int]=None,
|
||||
truncate: Optional[str]=None,
|
||||
temperature: Optional[int]=None,
|
||||
preset: Optional[str]=None,
|
||||
end_sequences: Optional[list]=None,
|
||||
stop_sequences: Optional[list]=None,
|
||||
k: Optional[int]=None,
|
||||
p: Optional[int]=None,
|
||||
frequency_penalty: Optional[int]=None,
|
||||
presence_penalty: Optional[int]=None,
|
||||
return_likelihoods: Optional[str]=None,
|
||||
logit_bias: Optional[dict]=None) -> None:
|
||||
|
||||
|
||||
num_generations: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
truncate: Optional[str] = None
|
||||
temperature: Optional[int] = None
|
||||
preset: Optional[str] = None
|
||||
end_sequences: Optional[list] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
return_likelihoods: Optional[str] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_generations: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
truncate: Optional[str] = None,
|
||||
temperature: Optional[int] = None,
|
||||
preset: Optional[str] = None,
|
||||
end_sequences: Optional[list] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
return_likelihoods: Optional[str] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
|
@ -100,6 +118,7 @@ def validate_environment(api_key):
|
|||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -119,9 +138,11 @@ def completion(
|
|||
prompt = " ".join(message["content"] for message in messages)
|
||||
|
||||
## Load Config
|
||||
config=litellm.CohereConfig.get_config()
|
||||
config = litellm.CohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
|
@ -132,16 +153,23 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
## error handling for cohere calls
|
||||
if response.status_code!=200:
|
||||
if response.status_code != 200:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
|
@ -149,11 +177,11 @@ def completion(
|
|||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
|
@ -168,18 +196,22 @@ def completion(
|
|||
for idx, item in enumerate(completion_response["generations"]):
|
||||
if len(item["text"]) > 0:
|
||||
message_obj = Message(content=item["text"])
|
||||
else:
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
raise CohereError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -189,11 +221,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -206,11 +239,7 @@ def embedding(
|
|||
headers = validate_environment(api_key)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {
"model": model,
"texts": input,
**optional_params
}
data = {"model": model, "texts": input, **optional_params}

if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
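# e.g. an embed-english-v3.0 request without an explicit input_type would
# effectively be sent as (illustrative values, not part of this diff):
#   {"model": "embed-english-v3.0", "texts": ["hello world"], "input_type": "search_document"}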
@ -218,21 +247,19 @@ def embedding(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
embed_url, headers=headers, data=json.dumps(data)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
"""
|
||||
response
|
||||
{
|
||||
|
@ -244,30 +271,23 @@ def embedding(
|
|||
'usage'
|
||||
}
|
||||
"""
|
||||
if response.status_code!=200:
|
||||
if response.status_code != 200:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
embeddings = response.json()['embeddings']
|
||||
embeddings = response.json()["embeddings"]
|
||||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding
|
||||
}
|
||||
{"object": "embedding", "index": idx, "embedding": embedding}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = output_data
|
||||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens+=len(encoding.encode(text))
|
||||
input_tokens += len(encoding.encode(text))
|
||||
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
"total_tokens": input_tokens,
|
||||
}
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
@ -1,20 +1,24 @@
|
|||
import time, json, httpx, asyncio
|
||||
|
||||
|
||||
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
||||
"""
|
||||
Async implementation of custom http transport
|
||||
"""
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
if "images/generations" in request.url.path and request.url.params[
|
||||
"api-version"
|
||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||
"2023-06-01-preview",
|
||||
"2023-07-01-preview",
|
||||
"2023-08-01-preview",
|
||||
"2023-09-01-preview",
|
||||
"2023-10-01-preview",
|
||||
"2023-10-01-preview",
|
||||
]:
|
||||
request.url = request.url.copy_with(path="/openai/images/generations:submit")
|
||||
request.url = request.url.copy_with(
|
||||
path="/openai/images/generations:submit"
|
||||
)
|
||||
response = await super().handle_async_request(request)
|
||||
operation_location_url = response.headers["operation-location"]
|
||||
request.url = httpx.URL(operation_location_url)
|
||||
|
@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
|||
start_time = time.time()
|
||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||
if time.time() - start_time > timeout_secs:
|
||||
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
|
||||
timeout = {
|
||||
"error": {
|
||||
"code": "Timeout",
|
||||
"message": "Operation polling timed out.",
|
||||
}
|
||||
}
|
||||
return httpx.Response(
|
||||
status_code=400,
|
||||
headers=response.headers,
|
||||
|
@ -56,26 +65,30 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
|||
)
|
||||
return await super().handle_async_request(request)


class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x

Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
|
||||
self,
|
||||
request: httpx.Request,
|
||||
) -> httpx.Response:
|
||||
if "images/generations" in request.url.path and request.url.params[
|
||||
"api-version"
|
||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||
"2023-06-01-preview",
|
||||
"2023-07-01-preview",
|
||||
"2023-08-01-preview",
|
||||
"2023-09-01-preview",
|
||||
"2023-10-01-preview",
|
||||
"2023-10-01-preview",
|
||||
]:
|
||||
request.url = request.url.copy_with(path="/openai/images/generations:submit")
|
||||
request.url = request.url.copy_with(
|
||||
path="/openai/images/generations:submit"
|
||||
)
|
||||
response = super().handle_request(request)
|
||||
operation_location_url = response.headers["operation-location"]
|
||||
request.url = httpx.URL(operation_location_url)
|
||||
|
@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport):
|
|||
start_time = time.time()
|
||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||
if time.time() - start_time > timeout_secs:
|
||||
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
|
||||
timeout = {
|
||||
"error": {
|
||||
"code": "Timeout",
|
||||
"message": "Operation polling timed out.",
|
||||
}
|
||||
}
|
||||
return httpx.Response(
|
||||
status_code=400,
|
||||
headers=response.headers,
|
||||
|
@ -115,4 +133,4 @@ class CustomHTTPTransport(httpx.HTTPTransport):
|
|||
content=json.dumps(result).encode("utf-8"),
|
||||
request=request,
|
||||
)
|
||||
return super().handle_request(request)
|
||||
return super().handle_request(request)
|
||||
|
|
|
@ -8,17 +8,22 @@ import litellm
|
|||
import sys, httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class GeminiError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class GeminiConfig():
|
||||
|
||||
class GeminiConfig:
|
||||
"""
|
||||
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
|
||||
|
||||
|
@ -37,33 +42,44 @@ class GeminiConfig():
|
|||
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
|
||||
"""
|
||||
|
||||
candidate_count: Optional[int]=None
|
||||
stop_sequences: Optional[list]=None
|
||||
max_output_tokens: Optional[int]=None
|
||||
temperature: Optional[float]=None
|
||||
top_p: Optional[float]=None
|
||||
top_k: Optional[int]=None
|
||||
candidate_count: Optional[int] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
|
||||
def __init__(self,
|
||||
candidate_count: Optional[int]=None,
|
||||
stop_sequences: Optional[list]=None,
|
||||
max_output_tokens: Optional[int]=None,
|
||||
temperature: Optional[float]=None,
|
||||
top_p: Optional[float]=None,
|
||||
top_k: Optional[int]=None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
candidate_count: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
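
# A minimal sketch of using GeminiConfig defaults through litellm (not part of
# this commit; the model string and GEMINI_API_KEY env var are assumptions):
#
#   import os, litellm
#
#   os.environ["GEMINI_API_KEY"] = "..."
#   litellm.GeminiConfig(temperature=0.2, max_output_tokens=256)
#   response = litellm.completion(
#       model="gemini/gemini-pro",
#       messages=[{"role": "user", "content": "Say hello"}],
#   )
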
def completion(
|
||||
|
@ -83,42 +99,50 @@ def completion(
|
|||
try:
|
||||
import google.generativeai as genai
|
||||
except:
|
||||
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
|
||||
raise Exception(
|
||||
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
)
|
||||
genai.configure(api_key=api_key)
|
||||
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages
|
||||
)
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini")
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="gemini"
|
||||
)
|
||||
|
||||
|
||||
## Load Config
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
|
||||
config = litellm.GeminiConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params.pop(
|
||||
"stream", None
|
||||
) # palm does not support streaming, so we handle this by fake streaming in main.py
|
||||
config = litellm.GeminiConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": {"inference_params": inference_params}},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": {"inference_params": inference_params}},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
_model = genai.GenerativeModel(f'models/{model}')
|
||||
response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params))
|
||||
try:
|
||||
_model = genai.GenerativeModel(f"models/{model}")
|
||||
response = _model.generate_content(
|
||||
contents=prompt,
|
||||
generation_config=genai.types.GenerationConfig(**inference_params),
|
||||
)
|
||||
except Exception as e:
|
||||
raise GeminiError(
|
||||
message=str(e),
|
||||
|
@ -127,11 +151,11 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response
|
||||
|
@ -142,31 +166,34 @@ def completion(
|
|||
message_obj = Message(content=item.content.parts[0].text)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(index=idx + 1, message=message_obj)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise GeminiError(message=traceback.format_exc(), status_code=response.status_code)
|
||||
|
||||
try:
|
||||
raise GeminiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
||||
try:
|
||||
completion_response = model_response["choices"][0]["message"].get("content")
|
||||
except:
|
||||
raise GeminiError(status_code=400, message=f"No response received. Original response - {response}")
|
||||
raise GeminiError(
|
||||
status_code=400,
|
||||
message=f"No response received. Original response - {response}",
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_str = ""
|
||||
prompt_str = ""
|
||||
for m in messages:
|
||||
if isinstance(m["content"], str):
|
||||
prompt_str += m["content"]
|
||||
elif isinstance(m["content"], list):
|
||||
for content in m["content"]:
|
||||
if content["type"] == "text":
|
||||
if content["type"] == "text":
|
||||
prompt_str += content["text"]
|
||||
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt_str)
|
||||
)
|
||||
prompt_tokens = len(encoding.encode(prompt_str))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -174,13 +201,14 @@ def completion(
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "gemini/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
|
|||
from typing import Optional
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class HuggingfaceError(Exception):
|
||||
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
|
||||
def __init__(
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
if request is not None:
|
||||
self.request = request
|
||||
else:
|
||||
self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
|
||||
else:
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api-inference.huggingface.co/models"
|
||||
)
|
||||
if response is not None:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
else:
|
||||
self.response = httpx.Response(
|
||||
status_code=status_code, request=self.request
|
||||
)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class HuggingfaceConfig():
|
||||
|
||||
class HuggingfaceConfig:
|
||||
"""
|
||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||
"""
|
||||
|
||||
best_of: Optional[int] = None
|
||||
decoder_input_details: Optional[bool] = None
|
||||
details: Optional[bool] = True # enables returning logprobs + best of
|
||||
details: Optional[bool] = True # enables returning logprobs + best of
|
||||
max_new_tokens: Optional[int] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
return_full_text: Optional[bool] = False # by default don't return the input as part of the output
|
||||
return_full_text: Optional[
|
||||
bool
|
||||
] = False # by default don't return the input as part of the output
|
||||
seed: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
|
@ -46,50 +61,66 @@ class HuggingfaceConfig():
|
|||
typical_p: Optional[float] = None
|
||||
watermark: Optional[bool] = None
|
||||
|
||||
def __init__(self,
|
||||
best_of: Optional[int] = None,
|
||||
decoder_input_details: Optional[bool] = None,
|
||||
details: Optional[bool] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: Optional[bool] = None,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_n_tokens: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
truncate: Optional[int] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: Optional[bool] = None
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
best_of: Optional[int] = None,
|
||||
decoder_input_details: Optional[bool] = None,
|
||||
details: Optional[bool] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: Optional[bool] = None,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_n_tokens: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
truncate: Optional[int] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}

def output_parser(generated_text: str):

def output_parser(generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
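
# For example (illustrative):
#
#   output_parser("<|assistant|> Hello there!</s>")      # -> " Hello there!"
#   output_parser("plain text with no special tokens")   # returned unchanged
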
tgi_models_cache = None
|
||||
conv_models_cache = None
|
||||
|
||||
|
||||
def read_tgi_conv_models():
|
||||
try:
|
||||
global tgi_models_cache, conv_models_cache
|
||||
|
@ -101,30 +132,38 @@ def read_tgi_conv_models():
|
|||
tgi_models = set()
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
# Construct the file path relative to the script's directory
|
||||
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt")
|
||||
file_path = os.path.join(
|
||||
script_directory,
|
||||
"huggingface_llms_metadata",
|
||||
"hf_text_generation_models.txt",
|
||||
)
|
||||
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
tgi_models.add(line.strip())
|
||||
|
||||
|
||||
# Cache the set for future use
|
||||
tgi_models_cache = tgi_models
|
||||
|
||||
|
||||
# If not, read the file and populate the cache
|
||||
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt")
|
||||
file_path = os.path.join(
|
||||
script_directory,
|
||||
"huggingface_llms_metadata",
|
||||
"hf_conversational_models.txt",
|
||||
)
|
||||
conv_models = set()
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
conv_models.add(line.strip())
|
||||
# Cache the set for future use
|
||||
conv_models_cache = conv_models
|
||||
conv_models_cache = conv_models
|
||||
return tgi_models, conv_models
|
||||
except:
|
||||
return set(), set()
|
||||
|
||||
|
||||
def get_hf_task_for_model(model):
|
||||
# read text file, cast it to set
|
||||
# read text file, cast it to set
|
||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||
tgi_models, conversational_models = read_tgi_conv_models()
|
||||
if model in tgi_models:
|
||||
|
@ -134,9 +173,10 @@ def get_hf_task_for_model(model):
|
|||
elif "roneneldan/TinyStories" in model:
return None
else:
return "text-generation-inference" # default to tgi
return "text-generation-inference"  # default to tgi
class Huggingface(BaseLLM):
|
||||
|
||||
class Huggingface(BaseLLM):
|
||||
_client_session: Optional[httpx.Client] = None
|
||||
_aclient_session: Optional[httpx.AsyncClient] = None
|
||||
|
||||
|
@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
|
|||
"content-type": "application/json",
|
||||
}
|
||||
if api_key and headers is None:
|
||||
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
|
||||
default_headers[
|
||||
"Authorization"
|
||||
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
|
||||
headers = default_headers
|
||||
elif headers:
|
||||
headers=headers
|
||||
else:
|
||||
headers = headers
|
||||
else:
|
||||
headers = default_headers
|
||||
return headers
|
||||
|
||||
def convert_to_model_response_object(self,
|
||||
completion_response,
|
||||
model_response,
|
||||
task,
|
||||
optional_params,
|
||||
encoding,
|
||||
input_text,
|
||||
model):
|
||||
if task == "conversational":
|
||||
if len(completion_response["generated_text"]) > 0: # type: ignore
|
||||
def convert_to_model_response_object(
|
||||
self,
|
||||
completion_response,
|
||||
model_response,
|
||||
task,
|
||||
optional_params,
|
||||
encoding,
|
||||
input_text,
|
||||
model,
|
||||
):
|
||||
if task == "conversational":
|
||||
if len(completion_response["generated_text"]) > 0: # type: ignore
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["generated_text"] # type: ignore
|
||||
elif task == "text-generation-inference":
|
||||
if (not isinstance(completion_response, list)
|
||||
] = completion_response[
|
||||
"generated_text"
|
||||
] # type: ignore
|
||||
elif task == "text-generation-inference":
|
||||
if (
|
||||
not isinstance(completion_response, list)
|
||||
or not isinstance(completion_response[0], dict)
|
||||
or "generated_text" not in completion_response[0]):
|
||||
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
|
||||
or "generated_text" not in completion_response[0]
|
||||
):
|
||||
raise HuggingfaceError(
|
||||
status_code=422,
|
||||
message=f"response is not in expected format - {completion_response}",
|
||||
)
|
||||
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = output_parser(completion_response[0]["generated_text"])
|
||||
## GETTING LOGPROBS + FINISH REASON
|
||||
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
|
||||
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||
completion_response[0]["generated_text"]
|
||||
)
|
||||
## GETTING LOGPROBS + FINISH REASON
|
||||
if (
|
||||
"details" in completion_response[0]
|
||||
and "tokens" in completion_response[0]["details"]
|
||||
):
|
||||
model_response.choices[0].finish_reason = completion_response[0][
|
||||
"details"
|
||||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response[0]["details"]["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
model_response["choices"][0]["message"]._logprob = sum_logprob
|
||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
|
||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||
if (
|
||||
"details" in completion_response[0]
|
||||
and "best_of_sequences" in completion_response[0]["details"]
|
||||
):
|
||||
choices_list = []
|
||||
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
|
||||
for idx, item in enumerate(
|
||||
completion_response[0]["details"]["best_of_sequences"]
|
||||
):
|
||||
sum_logprob = 0
|
||||
for token in item["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
if len(item["generated_text"]) > 0:
|
||||
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
|
||||
else:
|
||||
if len(item["generated_text"]) > 0:
|
||||
message_obj = Message(
|
||||
content=output_parser(item["generated_text"]),
|
||||
logprobs=sum_logprob,
|
||||
)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"].extend(choices_list)
|
||||
else:
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = output_parser(completion_response[0]["generated_text"])
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||
completion_response[0]["generated_text"]
|
||||
)
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = 0
|
||||
try:
|
||||
|
@ -221,12 +289,14 @@ class Huggingface(BaseLLM):
|
|||
completion_tokens = 0
|
||||
try:
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", "")
|
||||
)
|
||||
) ##[TODO] use the llama2 tokenizer here
|
||||
except:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
else:
|
||||
else:
|
||||
completion_tokens = 0
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
|
@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
model_response._hidden_params["original_response"] = completion_response
|
||||
return model_response
|
||||
|
||||
def completion(self,
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: Optional[str],
|
||||
|
@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
|
|||
completion_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
|
||||
## Load Config
|
||||
config=litellm.HuggingfaceConfig.get_config()
|
||||
config = litellm.HuggingfaceConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
### MAP INPUT PARAMS
|
||||
|
@ -298,11 +371,11 @@ class Huggingface(BaseLLM):
|
|||
generated_responses.append(message["content"])
|
||||
data = {
|
||||
"inputs": {
|
||||
"text": text,
|
||||
"past_user_inputs": past_user_inputs,
|
||||
"generated_responses": generated_responses
|
||||
"text": text,
|
||||
"past_user_inputs": past_user_inputs,
|
||||
"generated_responses": generated_responses,
|
||||
},
|
||||
"parameters": inference_params
|
||||
"parameters": inference_params,
|
||||
}
|
||||
input_text = "".join(message["content"] for message in messages)
|
||||
elif task == "text-generation-inference":
|
||||
|
@ -311,29 +384,39 @@ class Huggingface(BaseLLM):
|
|||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", None),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
messages=messages
|
||||
role_dict=model_prompt_details.get("roles", None),
|
||||
initial_prompt_value=model_prompt_details.get(
|
||||
"initial_prompt_value", ""
|
||||
),
|
||||
final_prompt_value=model_prompt_details.get(
|
||||
"final_prompt_value", ""
|
||||
),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
}
|
||||
input_text = prompt
|
||||
else:
|
||||
# Non TGI and Conversational llms
|
||||
# We need this branch, it removes 'details' and 'return_full_text' from params
|
||||
# We need this branch, it removes 'details' and 'return_full_text' from params
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get(
|
||||
"initial_prompt_value", ""
|
||||
),
|
||||
final_prompt_value=model_prompt_details.get(
|
||||
"final_prompt_value", ""
|
||||
),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
|
@ -346,52 +429,68 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
|
||||
"stream": True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
else False,
|
||||
}
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_text,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
|
||||
)
|
||||
input=input_text,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"task": task,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
"acompletion": acompletion,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
if acompletion is True:
|
||||
### ASYNC STREAMING
|
||||
if acompletion is True:
|
||||
### ASYNC STREAMING
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
|
||||
else:
|
||||
### ASYNC COMPLETION
|
||||
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
|
||||
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
|
||||
### SYNC STREAMING
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"]
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"],
|
||||
)
|
||||
return response.iter_lines()
|
||||
### SYNC COMPLETION
|
||||
else:
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data)
|
||||
completion_url, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
|
||||
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
|
||||
is_streamed = False
|
||||
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
|
||||
is_streamed = False
|
||||
if (
|
||||
response.__dict__["headers"].get("Content-Type", "")
|
||||
== "text/event-stream"
|
||||
):
|
||||
is_streamed = True
|
||||
|
||||
|
||||
# iterate over the complete streamed response, and return the final answer
|
||||
if is_streamed:
|
||||
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
|
||||
streamed_response = CustomStreamWrapper(
|
||||
completion_stream=response.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="huggingface",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
content = ""
|
||||
for chunk in streamed_response:
|
||||
for chunk in streamed_response:
|
||||
content += chunk["choices"][0]["delta"]["content"]
|
||||
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
|
||||
completion_response: List[Dict[str, Any]] = [
|
||||
{"generated_text": content}
|
||||
]
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input_text,
|
||||
|
@ -399,7 +498,7 @@ class Huggingface(BaseLLM):
|
|||
original_response=completion_response,
|
||||
additional_args={"complete_input_dict": data, "task": task},
|
||||
)
|
||||
else:
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input_text,
|
||||
|
@ -410,15 +509,20 @@ class Huggingface(BaseLLM):
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
if isinstance(completion_response, dict):
|
||||
if isinstance(completion_response, dict):
|
||||
completion_response = [completion_response]
|
||||
except:
|
||||
import traceback
|
||||
|
||||
raise HuggingfaceError(
|
||||
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
|
||||
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
print_verbose(f"response: {completion_response}")
|
||||
if isinstance(completion_response, dict) and "error" in completion_response:
|
||||
if (
|
||||
isinstance(completion_response, dict)
|
||||
and "error" in completion_response
|
||||
):
|
||||
print_verbose(f"completion error: {completion_response['error']}")
|
||||
print_verbose(f"response.status_code: {response.status_code}")
|
||||
raise HuggingfaceError(
|
||||
|
@ -432,75 +536,98 @@ class Huggingface(BaseLLM):
|
|||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
input_text=input_text,
|
||||
model=model
|
||||
model=model,
|
||||
)
|
||||
except HuggingfaceError as e:
|
||||
except HuggingfaceError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
else:
|
||||
import traceback
|
||||
|
||||
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
async def acompletion(self,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
task: str,
|
||||
encoding: Any,
|
||||
input_text: str,
|
||||
model: str,
|
||||
optional_params: dict):
|
||||
response = None
|
||||
try:
|
||||
async def acompletion(
|
||||
self,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
task: str,
|
||||
encoding: Any,
|
||||
input_text: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url=api_base, json=data, headers=headers, timeout=None)
|
||||
response = await client.post(
|
||||
url=api_base, json=data, headers=headers, timeout=None
|
||||
)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
|
||||
|
||||
raise HuggingfaceError(
|
||||
status_code=response.status_code,
|
||||
message=response.text,
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(completion_response=response_json,
|
||||
model_response=model_response,
|
||||
task=task,
|
||||
encoding=encoding,
|
||||
input_text=input_text,
|
||||
model=model,
|
||||
optional_params=optional_params)
|
||||
except Exception as e:
|
||||
if isinstance(e,httpx.TimeoutException):
|
||||
return self.convert_to_model_response_object(
|
||||
completion_response=response_json,
|
||||
model_response=model_response,
|
||||
task=task,
|
||||
encoding=encoding,
|
||||
input_text=input_text,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, httpx.TimeoutException):
|
||||
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
|
||||
elif response is not None and hasattr(response, "text"):
|
||||
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
|
||||
else:
|
||||
elif response is not None and hasattr(response, "text"):
|
||||
raise HuggingfaceError(
|
||||
status_code=500,
|
||||
message=f"{str(e)}\n\nOriginal Response: {response.text}",
|
||||
)
|
||||
else:
|
||||
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
|
||||
|
||||
async def async_streaming(self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str):
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = client.stream(
|
||||
"POST",
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers
|
||||
)
|
||||
async with response as r:
|
||||
"POST", url=f"{api_base}", json=data, headers=headers
|
||||
)
|
||||
async with response as r:
|
||||
if r.status_code != 200:
|
||||
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
|
||||
raise HuggingfaceError(
|
||||
status_code=r.status_code,
|
||||
message="An error occurred while streaming",
|
||||
)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=r.aiter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="huggingface",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
def embedding(self,
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: list,
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -523,65 +650,70 @@ class Huggingface(BaseLLM):
|
|||
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||
else:
|
||||
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
|
||||
if "sentence-transformers" in model:
|
||||
if len(input) == 0:
|
||||
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
|
||||
|
||||
if "sentence-transformers" in model:
|
||||
if len(input) == 0:
|
||||
raise HuggingfaceError(
|
||||
status_code=400,
|
||||
message="sentence transformers requires 2+ sentences",
|
||||
)
|
||||
data = {
|
||||
"inputs": {
|
||||
"source_sentence": input[0],
|
||||
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
|
||||
"source_sentence": input[0],
|
||||
"sentences": [
|
||||
"That is a happy dog",
|
||||
"That is a very happy person",
|
||||
"Today is a sunny day",
|
||||
],
|
||||
}
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"inputs": input # type: ignore
|
||||
}
|
||||
|
||||
data = {"inputs": input} # type: ignore
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
embed_url, headers=headers, data=json.dumps(data)
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": embed_url,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
embeddings = response.json()
|
||||
|
||||
if "error" in embeddings:
|
||||
raise HuggingfaceError(status_code=500, message=embeddings['error'])
|
||||
|
||||
if "error" in embeddings:
|
||||
raise HuggingfaceError(status_code=500, message=embeddings["error"])
|
||||
|
||||
output_data = []
|
||||
if "similarities" in embeddings:
|
||||
if "similarities" in embeddings:
|
||||
for idx, embedding in embeddings["similarities"]:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding, # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
if isinstance(embedding, float):
|
||||
if isinstance(embedding, float):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
"embedding": embedding, # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
elif isinstance(embedding, list) and isinstance(embedding[0], float):
|
||||
|
@ -589,15 +721,17 @@ class Huggingface(BaseLLM):
|
|||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
"embedding": embedding, # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
else:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding[0][0] # flatten list returned from hf
|
||||
"embedding": embedding[0][
|
||||
0
|
||||
], # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
|
@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
|
|||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens+=len(encoding.encode(text))
|
||||
input_tokens += len(encoding.encode(text))
|
||||
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
"total_tokens": input_tokens,
|
||||
}
|
||||
return model_response
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, Optional, List
|
|||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
|
||||
|
||||
class MaritalkError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -15,24 +16,26 @@ class MaritalkError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class MaritTalkConfig():
|
||||
|
||||
class MaritTalkConfig:
|
||||
"""
|
||||
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
|
||||
|
||||
|
||||
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
|
||||
|
||||
|
||||
- `model` (string): The model used for conversation. Default is 'maritalk'.
|
||||
|
||||
|
||||
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
|
||||
|
||||
|
||||
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
|
||||
|
||||
|
||||
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
|
||||
|
||||
|
||||
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
|
||||
|
||||
|
||||
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
do_sample: Optional[bool] = None
|
||||
|
@ -41,27 +44,40 @@ class MaritTalkConfig():
|
|||
repetition_penalty: Optional[float] = None
|
||||
stopping_tokens: Optional[List[str]] = None
|
||||
|
||||
def __init__(self,
|
||||
max_tokens: Optional[int]=None,
|
||||
model: Optional[str] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
stopping_tokens: Optional[List[str]] = None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
stopping_tokens: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
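
# A minimal sketch of a MaritTalk call through litellm (not part of this commit;
# the MARITALK_API_KEY env var name is an assumption):
#
#   import os, litellm
#
#   os.environ["MARITALK_API_KEY"] = "..."
#   response = litellm.completion(
#       model="maritalk",
#       messages=[{"role": "user", "content": "Olá, tudo bem?"}],
#       max_tokens=100,
#   )
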
def validate_environment(api_key):
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
@ -71,6 +87,7 @@ def validate_environment(api_key):
|
|||
headers["Authorization"] = f"Key {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -89,9 +106,11 @@ def completion(
|
|||
model = model
|
||||
|
||||
## Load Config
|
||||
config=litellm.MaritTalkConfig.get_config()
|
||||
config = litellm.MaritTalkConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
|
@ -101,24 +120,27 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
|
@ -130,15 +152,17 @@ def completion(
|
|||
else:
|
||||
try:
|
||||
if len(completion_response["answer"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["answer"]
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["answer"]
|
||||
except Exception as e:
|
||||
raise MaritalkError(message=response.text, status_code=response.status_code)
|
||||
raise MaritalkError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt = "".join(m["content"] for m in messages)
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -148,11 +172,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -161,4 +186,4 @@ def embedding(
|
|||
model_response=None,
|
||||
encoding=None,
|
||||
):
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, Usage


class NLPCloudError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code

@ -15,7 +16,8 @@ class NLPCloudError(Exception):
self.message
) # Call the base class constructor with the parameters it needs

class NLPCloudConfig():

class NLPCloudConfig:
|
||||
"""
|
||||
Reference: https://docs.nlpcloud.com/#generation
|
||||
|
||||
|
@ -43,45 +45,57 @@ class NLPCloudConfig():
|
|||
|
||||
- `num_return_sequences` (int): Optional. The number of independently computed returned sequences.
|
||||
"""
|
||||
max_length: Optional[int]=None
|
||||
length_no_input: Optional[bool]=None
|
||||
end_sequence: Optional[str]=None
|
||||
remove_end_sequence: Optional[bool]=None
|
||||
remove_input: Optional[bool]=None
|
||||
bad_words: Optional[list]=None
|
||||
temperature: Optional[float]=None
|
||||
top_p: Optional[float]=None
|
||||
top_k: Optional[int]=None
|
||||
repetition_penalty: Optional[float]=None
|
||||
num_beams: Optional[int]=None
|
||||
num_return_sequences: Optional[int]=None
|
||||
|
||||
max_length: Optional[int] = None
|
||||
length_no_input: Optional[bool] = None
|
||||
end_sequence: Optional[str] = None
|
||||
remove_end_sequence: Optional[bool] = None
|
||||
remove_input: Optional[bool] = None
|
||||
bad_words: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
num_beams: Optional[int] = None
|
||||
num_return_sequences: Optional[int] = None
|
||||
|
||||
def __init__(self,
|
||||
max_length: Optional[int]=None,
|
||||
length_no_input: Optional[bool]=None,
|
||||
end_sequence: Optional[str]=None,
|
||||
remove_end_sequence: Optional[bool]=None,
|
||||
remove_input: Optional[bool]=None,
|
||||
bad_words: Optional[list]=None,
|
||||
temperature: Optional[float]=None,
|
||||
top_p: Optional[float]=None,
|
||||
top_k: Optional[int]=None,
|
||||
repetition_penalty: Optional[float]=None,
|
||||
num_beams: Optional[int]=None,
|
||||
num_return_sequences: Optional[int]=None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_length: Optional[int] = None,
|
||||
length_no_input: Optional[bool] = None,
|
||||
end_sequence: Optional[str] = None,
|
||||
remove_end_sequence: Optional[bool] = None,
|
||||
remove_input: Optional[bool] = None,
|
||||
bad_words: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
num_beams: Optional[int] = None,
|
||||
num_return_sequences: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
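
The NLPCloudConfig class above follows the same pattern as the other provider configs in this commit: non-None constructor arguments become class-level defaults, and get_config() returns them so completion() can merge them into optional_params. A minimal sketch of that behaviour, assuming the class works exactly as written in this diff (values are illustrative):

# Sketch of the config-merge pattern used by completion() below (illustrative values).
import litellm

litellm.NLPCloudConfig(max_length=256, temperature=0.7)  # set provider-wide defaults

optional_params = {"temperature": 0.2}  # value passed explicitly on this call
for k, v in litellm.NLPCloudConfig.get_config().items():
    if k not in optional_params:  # per-call kwargs take precedence over config defaults
        optional_params[k] = v

# optional_params is now {"temperature": 0.2, "max_length": 256}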
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
|
@ -93,6 +107,7 @@ def validate_environment(api_key):
|
|||
headers["Authorization"] = f"Token {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -110,9 +125,11 @@ def completion(
|
|||
headers = validate_environment(api_key)
|
||||
|
||||
## Load Config
|
||||
config = litellm.NLPCloudConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.NLPCloudConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
completion_url_fragment_1 = api_base
|
||||
|
@ -129,24 +146,31 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=text,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
|
||||
)
|
||||
input=text,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return clean_and_iterate_chunks(response)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=text,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=text,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
|
@ -161,11 +185,16 @@ def completion(
|
|||
else:
|
||||
try:
|
||||
if len(completion_response["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["generated_text"]
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["generated_text"]
|
||||
except:
|
||||
raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code)
|
||||
raise NLPCloudError(
|
||||
message=json.dumps(completion_response),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = completion_response["nb_input_tokens"]
|
||||
completion_tokens = completion_response["nb_generated_tokens"]
|
||||
|
||||
|
@ -174,7 +203,7 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
@ -187,25 +216,27 @@ def completion(
|
|||
# # Perform further processing based on your needs
|
||||
# return cleaned_chunk
|
||||
|
||||
|
||||
# for line in response.iter_lines():
|
||||
# if line:
|
||||
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b''
buffer = b""

for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break

buffer += chunk
while b'\x00' in buffer:
buffer = buffer.replace(b'\x00', b'')
yield buffer.decode('utf-8')
buffer = b''
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""

# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode('utf-8')
yield buffer.decode("utf-8")
|
||||
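
A small usage sketch of clean_and_iterate_chunks above, with a stand-in object in place of the streaming requests.Response it normally receives:

# Illustration only: FakeResponse mimics the .iter_content() interface of the
# streaming requests.Response that clean_and_iterate_chunks consumes.
class FakeResponse:
    def iter_content(self, chunk_size=1024):
        yield b"Hello\x00\x00"   # pretend the provider pads chunks with NUL bytes
        yield b" world\x00"

for text in clean_and_iterate_chunks(FakeResponse()):
    print(repr(text))  # 'Hello' then ' world' -- NUL bytes stripped before decoding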
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
|
|
|
@ -2,10 +2,11 @@ import requests, types, time
|
|||
import json, uuid
|
||||
import traceback
|
||||
from typing import Optional
|
||||
import litellm
|
||||
import litellm
|
||||
import httpx, aiohttp, asyncio
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -16,14 +17,15 @@ class OllamaError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class OllamaConfig():
|
||||
|
||||
class OllamaConfig:
|
||||
"""
|
||||
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
|
||||
|
||||
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
|
||||
|
||||
|
||||
- `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0
|
||||
|
||||
|
||||
- `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1
|
||||
|
||||
- `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0
|
||||
|
@ -56,102 +58,134 @@ class OllamaConfig():
|
|||
|
||||
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
|
||||
"""
|
||||
mirostat: Optional[int]=None
|
||||
mirostat_eta: Optional[float]=None
|
||||
mirostat_tau: Optional[float]=None
|
||||
num_ctx: Optional[int]=None
|
||||
num_gqa: Optional[int]=None
|
||||
num_thread: Optional[int]=None
|
||||
repeat_last_n: Optional[int]=None
|
||||
repeat_penalty: Optional[float]=None
|
||||
temperature: Optional[float]=None
|
||||
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
|
||||
tfs_z: Optional[float]=None
|
||||
num_predict: Optional[int]=None
|
||||
top_k: Optional[int]=None
|
||||
top_p: Optional[float]=None
|
||||
system: Optional[str]=None
|
||||
template: Optional[str]=None
|
||||
|
||||
def __init__(self,
|
||||
mirostat: Optional[int]=None,
|
||||
mirostat_eta: Optional[float]=None,
|
||||
mirostat_tau: Optional[float]=None,
|
||||
num_ctx: Optional[int]=None,
|
||||
num_gqa: Optional[int]=None,
|
||||
num_thread: Optional[int]=None,
|
||||
repeat_last_n: Optional[int]=None,
|
||||
repeat_penalty: Optional[float]=None,
|
||||
temperature: Optional[float]=None,
|
||||
stop: Optional[list]=None,
|
||||
tfs_z: Optional[float]=None,
|
||||
num_predict: Optional[int]=None,
|
||||
top_k: Optional[int]=None,
|
||||
top_p: Optional[float]=None,
|
||||
system: Optional[str]=None,
|
||||
template: Optional[str]=None) -> None:
|
||||
mirostat: Optional[int] = None
|
||||
mirostat_eta: Optional[float] = None
|
||||
mirostat_tau: Optional[float] = None
|
||||
num_ctx: Optional[int] = None
|
||||
num_gqa: Optional[int] = None
|
||||
num_thread: Optional[int] = None
|
||||
repeat_last_n: Optional[int] = None
|
||||
repeat_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[
|
||||
list
|
||||
] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
|
||||
tfs_z: Optional[float] = None
|
||||
num_predict: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
system: Optional[str] = None
|
||||
template: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mirostat: Optional[int] = None,
|
||||
mirostat_eta: Optional[float] = None,
|
||||
mirostat_tau: Optional[float] = None,
|
||||
num_ctx: Optional[int] = None,
|
||||
num_gqa: Optional[int] = None,
|
||||
num_thread: Optional[int] = None,
|
||||
repeat_last_n: Optional[int] = None,
|
||||
repeat_penalty: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
stop: Optional[list] = None,
|
||||
tfs_z: Optional[float] = None,
|
||||
num_predict: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
system: Optional[str] = None,
|
||||
template: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
# ollama implementation
|
||||
def get_ollama_response(
|
||||
api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
model_response=None,
|
||||
encoding=None
|
||||
):
|
||||
api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
model_response=None,
|
||||
encoding=None,
|
||||
):
|
||||
if api_base.endswith("/api/generate"):
|
||||
url = api_base
|
||||
else:
|
||||
else:
|
||||
url = f"{api_base}/api/generate"
|
||||
|
||||
|
||||
## Load Config
|
||||
config=litellm.OllamaConfig.get_config()
|
||||
config = litellm.OllamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
optional_params["stream"] = optional_params.get("stream", False)
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
data = {"model": model, "prompt": prompt, **optional_params}
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=None,
|
||||
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
|
||||
additional_args={
|
||||
"api_base": url,
|
||||
"complete_input_dict": data,
|
||||
"headers": {},
|
||||
"acompletion": acompletion,
|
||||
},
|
||||
)
|
||||
if acompletion is True:
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False) == True:
|
||||
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
|
||||
response = ollama_async_streaming(
|
||||
url=url,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
|
||||
response = ollama_acompletion(
|
||||
url=url,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return response
|
||||
elif optional_params.get("stream", False) == True:
|
||||
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
|
||||
|
||||
|
||||
response = requests.post(
|
||||
url=f"{url}",
|
||||
json=data,
|
||||
)
|
||||
url=f"{url}",
|
||||
json=data,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
|
@ -168,52 +202,76 @@ def get_ollama_response(
|
|||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["finish_reason"] = "stop"
|
||||
if optional_params.get("format", "") == "json":
|
||||
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"arguments": response_json["response"], "name": ""},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + model
|
||||
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
|
||||
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
|
||||
completion_tokens = response_json["eval_count"]
|
||||
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
return model_response
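
When format == "json", the branch above wraps Ollama's raw JSON string in an OpenAI-style tool call instead of plain message content. A rough sketch of the resulting structure (payload is illustrative; the real code builds a litellm.Message):

# Illustrative shape of the message built by the "format" == "json" branch above.
import uuid

response_json = {"response": '{"city": "Boston", "unit": "celsius"}'}  # example Ollama output

message_shape = {
    "content": None,
    "tool_calls": [
        {
            "id": f"call_{uuid.uuid4()}",
            "function": {"arguments": response_json["response"], "name": ""},
            "type": "function",
        }
    ],
}
# model_response["choices"][0]["message"] ends up with this structure.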
|
||||
|
||||
|
||||
def ollama_completion_stream(url, data, logging_obj):
|
||||
with httpx.stream(
|
||||
url=url,
|
||||
json=data,
|
||||
method="POST",
|
||||
timeout=litellm.request_timeout
|
||||
) as response:
|
||||
try:
|
||||
url=url, json=data, method="POST", timeout=litellm.request_timeout
|
||||
) as response:
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
|
||||
raise OllamaError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
completion_stream=response.iter_lines(),
|
||||
model=data["model"],
|
||||
custom_llm_provider="ollama",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
|
||||
try:
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream(
|
||||
url=f"{url}",
|
||||
json=data,
|
||||
method="POST",
|
||||
timeout=litellm.request_timeout
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
completion_stream=response.aiter_lines(),
|
||||
model=data["model"],
|
||||
custom_llm_provider="ollama",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||
data["stream"] = False
|
||||
try:
|
||||
|
@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise OllamaError(status_code=resp.status, message=text)
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=data['prompt'],
|
||||
input=data["prompt"],
|
||||
api_key="",
|
||||
original_response=resp.text,
|
||||
additional_args={
|
||||
|
@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["finish_reason"] = "stop"
|
||||
if data.get("format", "") == "json":
|
||||
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {
|
||||
"arguments": response_json["response"],
|
||||
"name": "",
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||
model_response["choices"][0]["message"]["content"] = response_json[
|
||||
"response"
|
||||
]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + data['model']
|
||||
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
|
||||
model_response["model"] = "ollama/" + data["model"]
|
||||
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
|
||||
completion_tokens = response_json["eval_count"]
|
||||
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
async def ollama_aembeddings(api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
encoding=None):
|
||||
async def ollama_aembeddings(
|
||||
api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
encoding=None,
|
||||
):
|
||||
if api_base.endswith("/api/embeddings"):
|
||||
url = api_base
|
||||
else:
|
||||
else:
|
||||
url = f"{api_base}/api/embeddings"
|
||||
|
||||
|
||||
## Load Config
|
||||
config=litellm.OllamaConfig.get_config()
|
||||
config = litellm.OllamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
|
@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
|
|||
if response.status != 200:
|
||||
text = await response.text()
|
||||
raise OllamaError(status_code=response.status, message=text)
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
|
@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
|
|||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding
|
||||
}
|
||||
{"object": "embedding", "index": idx, "embedding": embedding}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = output_data
|
||||
model_response["model"] = model
|
||||
|
||||
input_tokens = len(encoding.encode(prompt))
|
||||
input_tokens = len(encoding.encode(prompt))
|
||||
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
"total_tokens": input_tokens,
|
||||
}
|
||||
return model_response
|
||||
return model_response
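
ollama_aembeddings above returns an OpenAI-style embedding response; a sketch of the structure it assembles, with made-up vector values:

# Illustrative shape of the model_response assembled by ollama_aembeddings above.
embeddings = [[0.1, 0.2, 0.3]]  # one vector per input prompt (made-up numbers)

response_shape = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": idx, "embedding": emb}
        for idx, emb in enumerate(embeddings)
    ],
    "model": "llama2",
    "usage": {
        "prompt_tokens": 6,  # len(encoding.encode(prompt)) in the real code
        "total_tokens": 6,
    },
}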
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class OobaboogaError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -15,6 +16,7 @@ class OobaboogaError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
@ -24,6 +26,7 @@ def validate_environment(api_key):
|
|||
headers["Authorization"] = f"Token {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -44,21 +47,24 @@ def completion(
|
|||
completion_url = model
|
||||
elif api_base:
|
||||
completion_url = api_base
|
||||
else:
|
||||
raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
|
||||
else:
|
||||
raise OobaboogaError(
|
||||
status_code=404,
|
||||
message="API Base not set. Set one via completion(..,api_base='your-api-url')",
|
||||
)
|
||||
model = model
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
|
||||
completion_url = completion_url + "/api/v1/generate"
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
|
@ -66,30 +72,35 @@ def completion(
|
|||
}
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise OobaboogaError(message=response.text, status_code=response.status_code)
|
||||
raise OobaboogaError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise OobaboogaError(
|
||||
message=completion_response["error"],
|
||||
|
@ -97,14 +108,17 @@ def completion(
|
|||
)
|
||||
else:
|
||||
try:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text']
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["results"][0]["text"]
|
||||
except:
|
||||
raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
|
||||
raise OobaboogaError(
|
||||
message=json.dumps(completion_response),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
@ -114,11 +128,12 @@ def completion(
|
|||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
File diff suppressed because it is too large
|
@@ -1,30 +1,41 @@
from typing import List, Dict
import types

class OpenrouterConfig():

class OpenrouterConfig:
"""
Reference: https://openrouter.ai/docs#format

"""

# OpenRouter-only parameters
extra_body: Dict[str, List[str]] = {
'transforms': [] # default transforms to []
}
extra_body: Dict[str, List[str]] = {"transforms": []}  # default transforms to []


def __init__(self,
transforms: List[str] = [],
models: List[str] = [],
route: str = '',
) -> None:
def __init__(
self,
transforms: List[str] = [],
models: List[str] = [],
route: str = "",
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
|
|
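
A minimal sketch of how the reformatted OpenrouterConfig behaves, assuming the class works exactly as written above (the transform name is a placeholder):

# Sketch only: constructor arguments become class attributes, and get_config()
# returns them alongside the class-level extra_body default.
import litellm

litellm.OpenrouterConfig(transforms=["middle-out"])  # placeholder transform name

config = litellm.OpenrouterConfig.get_config()
# config -> {"extra_body": {"transforms": []}, "transforms": ["middle-out"],
#            "models": [], "route": ""}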
|
@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
|
|||
import litellm
|
||||
import sys, httpx
|
||||
|
||||
|
||||
class PalmError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class PalmConfig():
|
||||
|
||||
class PalmConfig:
|
||||
"""
|
||||
Reference: https://developers.generativeai.google/api/python/google/generativeai/chat
|
||||
|
||||
|
@ -37,35 +42,47 @@ class PalmConfig():
|
|||
|
||||
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
|
||||
"""
|
||||
context: Optional[str]=None
|
||||
examples: Optional[list]=None
|
||||
temperature: Optional[float]=None
|
||||
candidate_count: Optional[int]=None
|
||||
top_k: Optional[int]=None
|
||||
top_p: Optional[float]=None
|
||||
max_output_tokens: Optional[int]=None
|
||||
|
||||
def __init__(self,
|
||||
context: Optional[str]=None,
|
||||
examples: Optional[list]=None,
|
||||
temperature: Optional[float]=None,
|
||||
candidate_count: Optional[int]=None,
|
||||
top_k: Optional[int]=None,
|
||||
top_p: Optional[float]=None,
|
||||
max_output_tokens: Optional[int]=None) -> None:
|
||||
|
||||
context: Optional[str] = None
|
||||
examples: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
candidate_count: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Optional[str] = None,
|
||||
examples: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
candidate_count: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def completion(
|
||||
|
@ -83,41 +100,43 @@ def completion(
|
|||
try:
|
||||
import google.generativeai as palm
|
||||
except:
|
||||
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
|
||||
raise Exception(
|
||||
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
)
|
||||
palm.configure(api_key=api_key)
|
||||
|
||||
model = model
|
||||
|
||||
|
||||
## Load Config
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
|
||||
config = litellm.PalmConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params.pop(
|
||||
"stream", None
|
||||
) # palm does not support streaming, so we handle this by fake streaming in main.py
|
||||
config = litellm.PalmConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": {"inference_params": inference_params}},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": {"inference_params": inference_params}},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
try:
|
||||
response = palm.generate_text(prompt=prompt, **inference_params)
|
||||
except Exception as e:
|
||||
raise PalmError(
|
||||
|
@ -127,11 +146,11 @@ def completion(
|
|||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response
|
||||
|
@ -142,22 +161,25 @@ def completion(
|
|||
message_obj = Message(content=item["output"])
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(index=idx+1, message=message_obj)
|
||||
choice_obj = Choices(index=idx + 1, message=message_obj)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
|
||||
|
||||
try:
|
||||
raise PalmError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
||||
try:
|
||||
completion_response = model_response["choices"][0]["message"].get("content")
|
||||
except:
|
||||
raise PalmError(status_code=400, message=f"No response received. Original response - {response}")
|
||||
raise PalmError(
|
||||
status_code=400,
|
||||
message=f"No response received. Original response - {response}",
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -165,13 +187,14 @@ def completion(
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "palm/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -8,6 +8,7 @@ import litellm
|
|||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class PetalsError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -16,7 +17,8 @@ class PetalsError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class PetalsConfig():
|
||||
|
||||
class PetalsConfig:
|
||||
"""
|
||||
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
|
||||
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
|
||||
|
@ -30,45 +32,64 @@ class PetalsConfig():
|
|||
- `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
|
||||
|
||||
- `temperature` (float, optional): This value sets the temperature for sampling.
|
||||
|
||||
|
||||
- `top_k` (integer, optional): This value sets the limit for top-k sampling.
|
||||
|
||||
|
||||
- `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
|
||||
|
||||
|
||||
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
|
||||
"""
|
||||
max_length: Optional[int]=None
|
||||
max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set
|
||||
do_sample: Optional[bool]=None
|
||||
temperature: Optional[float]=None
|
||||
top_k: Optional[int]=None
|
||||
top_p: Optional[float]=None
|
||||
repetition_penalty: Optional[float]=None
|
||||
|
||||
def __init__(self,
|
||||
max_length: Optional[int]=None,
|
||||
max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set
|
||||
do_sample: Optional[bool]=None,
|
||||
temperature: Optional[float]=None,
|
||||
top_k: Optional[int]=None,
|
||||
top_p: Optional[float]=None,
|
||||
repetition_penalty: Optional[float]=None) -> None:
|
||||
max_length: Optional[int] = None
|
||||
max_new_tokens: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # petals requires max tokens to be set
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_length: Optional[int] = None,
|
||||
max_new_tokens: Optional[
|
||||
int
|
||||
] = litellm.max_tokens, # petals requires max tokens to be set
|
||||
do_sample: Optional[bool] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
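
Note the PetalsConfig default above: max_new_tokens falls back to litellm.max_tokens because, per the inline comment, Petals requires max tokens to be set. A small sketch of the effect, assuming the class behaves as written:

# Sketch of the Petals default noted above (illustrative).
import litellm

optional_params = {}  # caller supplied nothing
for k, v in litellm.PetalsConfig.get_config().items():
    if k not in optional_params:
        optional_params[k] = v

# optional_params["max_new_tokens"] == litellm.max_tokens, the library-wide default;
# a caller-supplied value in optional_params would take precedence instead.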
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
|
@ -80,96 +101,97 @@ def completion(
|
|||
):
|
||||
## Load Config
|
||||
config = litellm.PetalsConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
if model in litellm.custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = litellm.custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
if api_base:
|
||||
if api_base:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": optional_params, "api_base": api_base},
|
||||
)
|
||||
data = {
|
||||
"model": model,
|
||||
"inputs": prompt,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
data = {"model": model, "inputs": prompt, **optional_params}
|
||||
|
||||
## COMPLETION CALL
|
||||
response = requests.post(api_base, data=data)
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
output_text = response.json()["outputs"]
|
||||
except Exception as e:
|
||||
PetalsError(status_code=response.status_code, message=str(e))
|
||||
|
||||
else:
|
||||
else:
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from petals import AutoDistributedModelForCausalLM # type: ignore
|
||||
from petals import AutoDistributedModelForCausalLM # type: ignore
|
||||
except:
|
||||
raise Exception(
|
||||
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
|
||||
)
|
||||
|
||||
|
||||
model = model
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model, use_fast=False, add_bos_token=False
|
||||
)
|
||||
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
|
||||
|
||||
|
||||
# optional params: max_new_tokens=1,temperature=0.9, top_p=0.6
|
||||
outputs = model_obj.generate(inputs, **optional_params)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=outputs,
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=outputs,
|
||||
additional_args={"complete_input_dict": optional_params},
|
||||
)
|
||||
## RESPONSE OBJECT
|
||||
output_text = tokenizer.decode(outputs[0])
|
||||
|
||||
|
||||
if len(output_text) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = output_text
|
||||
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content"))
|
||||
)
|
||||
|
@ -177,13 +199,14 @@ def completion(
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
|
|
@ -4,11 +4,13 @@ import json
|
|||
from jinja2 import Template, exceptions, Environment, meta
|
||||
from typing import Optional, Any
|
||||
|
||||
|
||||
def default_pt(messages):
|
||||
return " ".join(message["content"] for message in messages)
|
||||
|
||||
# alpaca prompt template - for models like mythomax, etc.
|
||||
def alpaca_pt(messages):
|
||||
|
||||
# alpaca prompt template - for models like mythomax, etc.
|
||||
def alpaca_pt(messages):
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
"system": {
|
||||
|
@ -19,59 +21,56 @@ def alpaca_pt(messages):
|
|||
"pre_message": "### Instruction:\n",
|
||||
"post_message": "\n\n",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "### Response:\n",
|
||||
"post_message": "\n\n"
|
||||
}
|
||||
"assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
|
||||
},
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
messages=messages
|
||||
messages=messages,
|
||||
)
|
||||
return prompt
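
alpaca_pt above delegates the actual assembly to custom_prompt (defined further down in this file, outside this hunk); roughly, user and assistant turns are wrapped with the "### Instruction:" / "### Response:" markers from the role_dict. Illustrative call; the exact output depends on custom_prompt:

# Illustrative call; the precise string (bos/eos placement, spacing) is decided
# by custom_prompt(), which is not shown in this hunk.
messages = [
    {"role": "user", "content": "Summarize this repo."},
    {"role": "assistant", "content": "It wraps many LLM APIs behind one interface."},
]
prompt = alpaca_pt(messages)
# prompt contains "### Instruction:\nSummarize this repo.\n\n"
# followed by "### Response:\nIt wraps many LLM APIs behind one interface.\n\n"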
|
||||
|
||||
|
||||
# Llama2 prompt template
|
||||
def llama_2_chat_pt(messages):
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
"system": {
|
||||
"pre_message": "[INST] <<SYS>>\n",
|
||||
"post_message": "\n<</SYS>>\n [/INST]\n"
|
||||
"post_message": "\n<</SYS>>\n [/INST]\n",
|
||||
},
|
||||
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
|
||||
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
|
||||
"pre_message": "[INST] ",
|
||||
"post_message": " [/INST]\n"
|
||||
},
|
||||
"post_message": " [/INST]\n",
|
||||
},
|
||||
"assistant": {
|
||||
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
|
||||
}
|
||||
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
|
||||
},
|
||||
},
|
||||
messages=messages,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>"
|
||||
eos_token="</s>",
|
||||
)
|
||||
return prompt
|
||||
|
||||
def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
|
||||
if "instruct" in model:
|
||||
|
||||
def ollama_pt(
|
||||
model, messages
|
||||
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
if "instruct" in model:
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
"system": {
|
||||
"pre_message": "### System:\n",
|
||||
"post_message": "\n"
|
||||
},
|
||||
"system": {"pre_message": "### System:\n", "post_message": "\n"},
|
||||
"user": {
|
||||
"pre_message": "### User:\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "### Response:\n",
|
||||
"post_message": "\n",
|
||||
}
|
||||
},
|
||||
},
|
||||
final_prompt_value="### Response:",
|
||||
messages=messages
|
||||
messages=messages,
|
||||
)
|
||||
elif "llava" in model:
|
||||
prompt = ""
|
||||
|
@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
|
|||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
images.append(image_url)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"images": images
|
||||
}
|
||||
else:
|
||||
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
|
||||
return {"prompt": prompt, "images": images}
|
||||
else:
|
||||
prompt = "".join(
|
||||
m["content"]
|
||||
if isinstance(m["content"], str) is str
|
||||
else "".join(m["content"])
|
||||
for m in messages
|
||||
)
|
||||
return prompt
|
||||
|
||||
def mistral_instruct_pt(messages):
|
||||
|
||||
def mistral_instruct_pt(messages):
|
||||
prompt = custom_prompt(
|
||||
initial_prompt_value="<s>",
|
||||
role_dict={
|
||||
"system": {
|
||||
"pre_message": "[INST]",
|
||||
"post_message": "[/INST]"
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "[INST]",
|
||||
"post_message": "[/INST]"
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "[INST]",
|
||||
"post_message": "[/INST]"
|
||||
}
|
||||
"system": {"pre_message": "[INST]", "post_message": "[/INST]"},
|
||||
"user": {"pre_message": "[INST]", "post_message": "[/INST]"},
|
||||
"assistant": {"pre_message": "[INST]", "post_message": "[/INST]"},
|
||||
},
|
||||
final_prompt_value="</s>",
|
||||
messages=messages
|
||||
messages=messages,
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
|
||||
def falcon_instruct_pt(messages):
|
||||
prompt = ""
|
||||
|
@ -125,11 +119,16 @@ def falcon_instruct_pt(messages):
|
|||
if message["role"] == "system":
|
||||
prompt += message["content"]
|
||||
else:
|
||||
prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
|
||||
prompt += (
|
||||
message["role"]
|
||||
+ ":"
|
||||
+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
|
||||
)
|
||||
prompt += "\n\n"
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def falcon_chat_pt(messages):
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
|
@ -142,6 +141,7 @@ def falcon_chat_pt(messages):
|
|||
|
||||
return prompt
|
||||
|
||||
|
||||
# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
|
||||
def mpt_chat_pt(messages):
|
||||
prompt = ""
|
||||
|
@ -154,18 +154,20 @@ def mpt_chat_pt(messages):
|
|||
prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
|
||||
return prompt
|
||||
|
||||
|
||||
# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
|
||||
def wizardcoder_pt(messages):
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
prompt += message["content"] + "\n\n"
|
||||
elif message["role"] == "user": # map to 'Instruction'
|
||||
elif message["role"] == "user": # map to 'Instruction'
|
||||
prompt += "### Instruction:\n" + message["content"] + "\n\n"
|
||||
elif message["role"] == "assistant": # map to 'Response'
|
||||
elif message["role"] == "assistant": # map to 'Response'
|
||||
prompt += "### Response:\n" + message["content"] + "\n\n"
|
||||
return prompt
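
wizardcoder_pt is shown in full above, so its output can be traced directly; for example:

# Traced from the function body above.
messages = [
    {"role": "system", "content": "You write Python."},
    {"role": "user", "content": "Reverse a string."},
]
prompt = wizardcoder_pt(messages)
# prompt == "You write Python.\n\n### Instruction:\nReverse a string.\n\n"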
|
||||
|
||||
|
||||
|
||||
# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
|
||||
def phind_codellama_pt(messages):
|
||||
prompt = ""
|
||||
|
@ -178,13 +180,17 @@ def phind_codellama_pt(messages):
|
|||
prompt += "### Assistant\n" + message["content"] + "\n\n"
|
||||
return prompt
|
||||
|
||||
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
|
||||
|
||||
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
|
||||
## get the tokenizer config from huggingface
|
||||
bos_token = ""
|
||||
eos_token = ""
|
||||
if chat_template is None:
|
||||
if chat_template is None:
|
||||
|
||||
def _get_tokenizer_config(hf_model_name):
|
||||
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
|
||||
url = (
|
||||
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
|
||||
)
|
||||
# Make a GET request to fetch the JSON data
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
|
@ -193,10 +199,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
|
|||
return {"status": "success", "tokenizer": tokenizer_config}
|
||||
else:
|
||||
return {"status": "failure"}
|
||||
|
||||
tokenizer_config = _get_tokenizer_config(model)
|
||||
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
|
||||
if (
|
||||
tokenizer_config["status"] == "failure"
|
||||
or "chat_template" not in tokenizer_config["tokenizer"]
|
||||
):
|
||||
raise Exception("No chat template found")
|
||||
## read the bos token, eos token and chat template from the json
|
||||
## read the bos token, eos token and chat template from the json
|
||||
tokenizer_config = tokenizer_config["tokenizer"]
|
||||
bos_token = tokenizer_config["bos_token"]
|
||||
eos_token = tokenizer_config["eos_token"]
|
||||
|
@ -204,10 +214,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
|
|||
|
||||
def raise_exception(message):
|
||||
raise Exception(f"Error message - {message}")
|
||||
|
||||
|
||||
# Create a template object from the template text
|
||||
env = Environment()
|
||||
env.globals['raise_exception'] = raise_exception
|
||||
env.globals["raise_exception"] = raise_exception
|
||||
try:
|
||||
template = env.from_string(chat_template)
|
||||
except Exception as e:
|
||||
|
@ -216,137 +226,167 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
|
|||
def _is_system_in_template():
|
||||
try:
|
||||
# Try rendering the template with a system message
|
||||
response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>")
|
||||
response = template.render(
|
||||
messages=[{"role": "system", "content": "test"}],
|
||||
eos_token="<eos>",
|
||||
bos_token="<bos>",
|
||||
)
|
||||
return True
|
||||
|
||||
# This will be raised if Jinja attempts to render the system message and it can't
|
||||
except:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
# Render the template with the provided values
|
||||
if _is_system_in_template():
|
||||
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
|
||||
else:
|
||||
if _is_system_in_template():
|
||||
rendered_text = template.render(
|
||||
bos_token=bos_token, eos_token=eos_token, messages=messages
|
||||
)
|
||||
else:
|
||||
# treat a system message as a user message, if system not in template
|
||||
try:
|
||||
reformatted_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
reformatted_messages.append({"role": "user", "content": message["content"]})
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
reformatted_messages.append(
|
||||
{"role": "user", "content": message["content"]}
|
||||
)
|
||||
else:
|
||||
reformatted_messages.append(message)
|
||||
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages)
|
||||
rendered_text = template.render(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
messages=reformatted_messages,
|
||||
)
|
||||
except Exception as e:
|
||||
if "Conversation roles must alternate user/assistant" in str(e):
|
||||
if "Conversation roles must alternate user/assistant" in str(e):
|
||||
# reformat messages to ensure user/assistant are alternating; if there are either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
|
||||
new_messages = []
|
||||
for i in range(len(reformatted_messages)-1):
|
||||
for i in range(len(reformatted_messages) - 1):
|
||||
new_messages.append(reformatted_messages[i])
|
||||
if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]:
|
||||
if (
|
||||
reformatted_messages[i]["role"]
|
||||
== reformatted_messages[i + 1]["role"]
|
||||
):
|
||||
if reformatted_messages[i]["role"] == "user":
|
||||
new_messages.append({"role": "assistant", "content": ""})
|
||||
new_messages.append(
|
||||
{"role": "assistant", "content": ""}
|
||||
)
|
||||
else:
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
new_messages.append(reformatted_messages[-1])
|
||||
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
|
||||
rendered_text = template.render(
|
||||
bos_token=bos_token, eos_token=eos_token, messages=new_messages
|
||||
)
|
||||
return rendered_text
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise Exception(f"Error rendering template - {str(e)}")
|
||||
|
||||
# Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
def claude_2_1_pt(
    messages: list,
):  # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
    """
    Claude v2.1 allows system prompts (no Human: needed), but requires that they be followed by a Human: turn
    - you can't just pass a system message
    - you can't pass a system message and follow it with an assistant message
    if a system message is passed in, you can only do system, human, assistant or system, human

    if a system message is passed in and is followed by an assistant message, insert a blank human message between them.
    """
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
prompt = ""
|
||||
for idx, message in enumerate(messages):
|
||||
|
||||
prompt = ""
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||
)
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||
elif message["role"] == "system":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
if idx > 0 and messages[idx - 1]["role"] == "system":
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
|
||||
prompt += (
|
||||
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
|
||||
)
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
|
||||
if idx > 0 and messages[idx - 1]["role"] == "system":
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
|
||||
return prompt
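
# Usage sketch (hypothetical input): the system prompt is emitted bare, user turns
# get the "\n\nHuman: " prefix, and the prompt always ends on an open assistant turn.
# claude_2_1_pt([
#     {"role": "system", "content": "You are concise."},
#     {"role": "user", "content": "Hi"},
# ])
# -> "You are concise.\n\nHuman: Hi\n\nAssistant: "
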
|
||||
|
||||
### TOGETHER AI

def get_model_info(token, model):
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}'
|
||||
}
|
||||
response = requests.get('https://api.together.xyz/models/info', headers=headers)
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = requests.get("https://api.together.xyz/models/info", headers=headers)
|
||||
if response.status_code == 200:
|
||||
model_info = response.json()
|
||||
for m in model_info:
|
||||
if m["name"].lower().strip() == model.strip():
|
||||
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
|
||||
for m in model_info:
|
||||
if m["name"].lower().strip() == model.strip():
|
||||
return m["config"].get("prompt_format", None), m["config"].get(
|
||||
"chat_template", None
|
||||
)
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
except Exception as e: # safely fail a prompt template request
|
||||
except Exception as e: # safely fail a prompt template request
|
||||
return None, None
|
||||
|
||||
|
||||
def format_prompt_togetherai(messages, prompt_format, chat_template):
|
||||
if prompt_format is None:
|
||||
return default_pt(messages)
|
||||
|
||||
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
|
||||
|
||||
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
|
||||
|
||||
if chat_template is not None:
|
||||
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
|
||||
elif prompt_format is not None:
|
||||
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
|
||||
else:
|
||||
prompt = hf_chat_template(
|
||||
model=None, messages=messages, chat_template=chat_template
|
||||
)
|
||||
elif prompt_format is not None:
|
||||
prompt = custom_prompt(
|
||||
role_dict={},
|
||||
messages=messages,
|
||||
initial_prompt_value=human_prompt,
|
||||
final_prompt_value=assistant_prompt,
|
||||
)
|
||||
else:
|
||||
prompt = default_pt(messages)
|
||||
return prompt
|
||||
return prompt
|
||||
|
||||
|
||||
###
|
||||
|
||||
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
|
||||
|
||||
def anthropic_pt(
|
||||
messages: list,
|
||||
): # format - https://docs.anthropic.com/claude/reference/complete_post
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
prompt = ""
|
||||
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
|
||||
|
||||
prompt = ""
|
||||
for idx, message in enumerate(
|
||||
messages
|
||||
): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||
)
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||
elif message["role"] == "system":
|
||||
prompt += (
|
||||
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
|
||||
)
|
||||
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
|
||||
else:
|
||||
prompt += (
|
||||
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
|
||||
)
|
||||
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
|
||||
if (
|
||||
idx == 0 and message["role"] == "assistant"
|
||||
): # ensure the prompt always starts with `\n\nHuman: `
|
||||
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
|
||||
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
|
||||
return prompt
|
||||
return prompt
|
||||
|
||||
def gemini_text_image_pt(messages: list):
|
||||
|
||||
def gemini_text_image_pt(messages: list):
|
||||
"""
|
||||
{
|
||||
"contents":[
|
||||
|
@ -367,13 +407,15 @@ def gemini_text_image_pt(messages: list):
|
|||
try:
|
||||
import google.generativeai as genai
|
||||
except:
|
||||
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
|
||||
|
||||
raise Exception(
|
||||
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
)
|
||||
|
||||
prompt = ""
|
||||
images = []
|
||||
for message in messages:
|
||||
images = []
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str):
|
||||
prompt += message["content"]
|
||||
prompt += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
|
||||
for element in message["content"]:
|
||||
|
@ -383,45 +425,63 @@ def gemini_text_image_pt(messages: list):
|
|||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
images.append(image_url)
|
||||
|
||||
|
||||
content = [prompt] + images
|
||||
return content
|
||||
|
||||
# Function call template
def function_call_prompt(messages: list, functions: list):
    function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
    function_prompt = (
        "Produce JSON OUTPUT ONLY! The following functions are available to you:"
    )
    for function in functions:
        function_prompt += f"""\n{function}\n"""

    function_added_to_prompt = False
    for message in messages:
        if "system" in message["role"]:
            message['content'] += f"""{function_prompt}"""
            message["content"] += f"""{function_prompt}"""
            function_added_to_prompt = True

    if function_added_to_prompt == False:
        messages.append({'role': 'system', 'content': f"""{function_prompt}"""})
        messages.append({"role": "system", "content": f"""{function_prompt}"""})

    return messages
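
# Usage sketch (hypothetical values): with no system message present, the serialized
# function schemas are appended as a new system message instead.
# function_call_prompt(
#     messages=[{"role": "user", "content": "What's the weather in SF?"}],
#     functions=[{"name": "get_current_weather", "parameters": {"type": "object"}}],
# )
# -> [
#     {"role": "user", "content": "What's the weather in SF?"},
#     {"role": "system", "content": "Produce JSON OUTPUT ONLY! The following functions are available to you:\n{'name': 'get_current_weather', 'parameters': {'type': 'object'}}\n"},
# ]
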
|
||||
|
||||
|
||||
# Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""):
def custom_prompt(
    role_dict: dict,
    messages: list,
    initial_prompt_value: str = "",
    final_prompt_value: str = "",
    bos_token: str = "",
    eos_token: str = "",
):
    prompt = bos_token + initial_prompt_value
    bos_open = True
    ## a bos token is at the start of a system / human message
    ## an eos token is at the end of the assistant response to the message
    for message in messages:
        role = message["role"]

        if role in ["system", "human"] and not bos_open:
            prompt += bos_token
            bos_open = True

        pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
        post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
        pre_message_str = (
            role_dict[role]["pre_message"]
            if role in role_dict and "pre_message" in role_dict[role]
            else ""
        )
        post_message_str = (
            role_dict[role]["post_message"]
            if role in role_dict and "post_message" in role_dict[role]
            else ""
        )
        prompt += pre_message_str + message["content"] + post_message_str

        if role == "assistant":
            prompt += eos_token
            bos_open = False

@ -429,25 +489,35 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
    prompt += final_prompt_value
    return prompt
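
# Usage sketch (hypothetical role_dict): wrap each role in its own pre/post markers;
# the eos token is only appended after assistant turns.
# custom_prompt(
#     role_dict={
#         "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
#         "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
#         "assistant": {"pre_message": "", "post_message": ""},
#     },
#     messages=[{"role": "user", "content": "Hi"}],
#     bos_token="<s>",
#     eos_token="</s>",
# )
# -> "<s>[INST] Hi [/INST]"
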
|
||||
|
||||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
|
||||
|
||||
def prompt_factory(
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
original_model_name = model
|
||||
model = model.lower()
|
||||
if custom_llm_provider == "ollama":
|
||||
if custom_llm_provider == "ollama":
|
||||
return ollama_pt(model=model, messages=messages)
|
||||
elif custom_llm_provider == "anthropic":
|
||||
if "claude-2.1" in model:
|
||||
if "claude-2.1" in model:
|
||||
return claude_2_1_pt(messages=messages)
|
||||
else:
|
||||
else:
|
||||
return anthropic_pt(messages=messages)
|
||||
elif custom_llm_provider == "together_ai":
|
||||
elif custom_llm_provider == "together_ai":
|
||||
prompt_format, chat_template = get_model_info(token=api_key, model=model)
|
||||
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
|
||||
elif custom_llm_provider == "gemini":
|
||||
return format_prompt_togetherai(
|
||||
messages=messages, prompt_format=prompt_format, chat_template=chat_template
|
||||
)
|
||||
elif custom_llm_provider == "gemini":
|
||||
return gemini_text_image_pt(messages=messages)
|
||||
try:
|
||||
if "meta-llama/llama-2" in model and "chat" in model:
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
|
||||
elif (
|
||||
"tiiuae/falcon" in model
|
||||
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
|
||||
if model == "tiiuae/falcon-180B-chat":
|
||||
return falcon_chat_pt(messages=messages)
|
||||
elif "instruct" in model:
|
||||
|
@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
|||
return mpt_chat_pt(messages=messages)
|
||||
elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
|
||||
if "instruct" in model:
|
||||
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
|
||||
return llama_2_chat_pt(
|
||||
messages=messages
|
||||
) # https://huggingface.co/blog/codellama#conversational-instructions
|
||||
elif "wizardlm/wizardcoder" in model:
|
||||
return wizardcoder_pt(messages=messages)
|
||||
elif "phind/phind-codellama" in model:
|
||||
return phind_codellama_pt(messages=messages)
|
||||
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
|
||||
elif "togethercomputer/llama-2" in model and (
|
||||
"instruct" in model or "chat" in model
|
||||
):
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]:
|
||||
return alpaca_pt(messages=messages)
|
||||
else:
|
||||
elif model in [
|
||||
"gryphe/mythomax-l2-13b",
|
||||
"gryphe/mythomix-l2-13b",
|
||||
"gryphe/mythologic-l2-13b",
|
||||
]:
|
||||
return alpaca_pt(messages=messages)
|
||||
else:
|
||||
return hf_chat_template(original_model_name, messages)
|
||||
except Exception as e:
|
||||
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
||||
|
||||
return default_pt(
|
||||
messages=messages
|
||||
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
||||
|
|
|
@ -4,81 +4,100 @@ import requests
|
|||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
import litellm
|
||||
import httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class ReplicateError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.replicate.com/v1/deployments"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class ReplicateConfig():
|
||||
|
||||
class ReplicateConfig:
|
||||
"""
|
||||
Reference: https://replicate.com/meta/llama-2-70b-chat/api
|
||||
- `prompt` (string): The prompt to send to the model.
|
||||
|
||||
|
||||
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
|
||||
|
||||
|
||||
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
|
||||
|
||||
|
||||
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
|
||||
|
||||
|
||||
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
|
||||
|
||||
|
||||
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
|
||||
|
||||
|
||||
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
|
||||
|
||||
|
||||
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
|
||||
|
||||
|
||||
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
|
||||
|
||||
|
||||
- `debug` (boolean): If set to `True`, it provides debugging output in logs.
|
||||
|
||||
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
|
||||
"""
|
||||
system_prompt: Optional[str]=None
|
||||
max_new_tokens: Optional[int]=None
|
||||
min_new_tokens: Optional[int]=None
|
||||
temperature: Optional[int]=None
|
||||
top_p: Optional[int]=None
|
||||
top_k: Optional[int]=None
|
||||
stop_sequences: Optional[str]=None
|
||||
seed: Optional[int]=None
|
||||
debug: Optional[bool]=None
|
||||
|
||||
def __init__(self,
|
||||
system_prompt: Optional[str]=None,
|
||||
max_new_tokens: Optional[int]=None,
|
||||
min_new_tokens: Optional[int]=None,
|
||||
temperature: Optional[int]=None,
|
||||
top_p: Optional[int]=None,
|
||||
top_k: Optional[int]=None,
|
||||
stop_sequences: Optional[str]=None,
|
||||
seed: Optional[int]=None,
|
||||
debug: Optional[bool]=None) -> None:
|
||||
system_prompt: Optional[str] = None
|
||||
max_new_tokens: Optional[int] = None
|
||||
min_new_tokens: Optional[int] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
debug: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: Optional[str] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
min_new_tokens: Optional[int] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
debug: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
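
    # Usage sketch (hypothetical values): class-level defaults set via
    # litellm.ReplicateConfig(max_new_tokens=256) are merged into a call's
    # optional_params only for keys the caller did not pass, so per-request
    # arguments always win:
    #     for k, v in litellm.ReplicateConfig.get_config().items():
    #         if k not in optional_params:
    #             optional_params[k] = v
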
|
||||
|
||||
|
||||
# Function to start a prediction and get the prediction URL
|
||||
def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
|
||||
def start_prediction(
|
||||
version_id, input_data, api_token, api_base, logging_obj, print_verbose
|
||||
):
|
||||
base_url = api_base
|
||||
if "deployments" in version_id:
|
||||
print_verbose("\nLiteLLM: Request to custom replicate deployment")
|
||||
|
@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
|
|||
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
initial_prediction_data = {
|
||||
|
@ -98,24 +117,33 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
|
|||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": initial_prediction_data,
|
||||
"headers": headers,
|
||||
"api_base": base_url,
|
||||
},
|
||||
)
|
||||
|
||||
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
|
||||
response = requests.post(
|
||||
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
|
||||
)
|
||||
if response.status_code == 201:
|
||||
response_data = response.json()
|
||||
return response_data.get("urls", {}).get("get")
|
||||
else:
|
||||
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
|
||||
raise ReplicateError(
|
||||
response.status_code, f"Failed to start prediction {response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (non-streaming)
|
||||
def handle_prediction_response(prediction_url, api_token, print_verbose):
|
||||
output_string = ""
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
status = ""
|
||||
|
@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
|
|||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data['output'])
|
||||
output_string = "".join(response_data["output"])
|
||||
print_verbose(f"Non-streamed output:{output_string}")
|
||||
status = response_data.get('status', None)
|
||||
status = response_data.get("status", None)
|
||||
logs = response_data.get("logs", "")
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}")
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose("Replicate: Failed to fetch prediction status and output.")
|
||||
return output_string, logs
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
|
||||
previous_output = ""
|
||||
|
@ -146,30 +178,34 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
status = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
time.sleep(0.5) # prevent being rate limited by replicate
|
||||
time.sleep(0.5) # prevent being rate limited by replicate
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
response = requests.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
status = response_data['status']
|
||||
status = response_data["status"]
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data['output'])
|
||||
new_output = output_string[len(previous_output):]
|
||||
output_string = "".join(response_data["output"])
|
||||
new_output = output_string[len(previous_output) :]
|
||||
print_verbose(f"New chunk: {new_output}")
|
||||
yield {"output": new_output, "status": status}
|
||||
previous_output = output_string
|
||||
status = response_data['status']
|
||||
status = response_data["status"]
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
|
||||
raise ReplicateError(
|
||||
status_code=400, message=f"Error: {replicate_error}"
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
|
||||
|
||||
print_verbose(
|
||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
|
@ -178,11 +214,12 @@ def model_to_version_id(model):
|
|||
return split_model[1]
|
||||
return model
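
# Usage sketch (hypothetical values): a Replicate model string that carries a
# version id after the colon resolves to just the version id; otherwise the
# model string is returned unchanged.
# model_to_version_id("meta/llama-2-70b-chat:02e509c7...")  # -> "02e509c7..."
# model_to_version_id("my-custom-deployment")               # -> "my-custom-deployment"
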
|
||||
|
||||
|
||||
# Main function for prediction completion
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
|
@ -196,35 +233,37 @@ def completion(
|
|||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
## Load Config
|
||||
config = litellm.ReplicateConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.ReplicateConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
|
||||
system_prompt = None
|
||||
if optional_params is not None and "supports_system_prompt" in optional_params:
|
||||
supports_sys_prompt = optional_params.pop("supports_system_prompt")
|
||||
else:
|
||||
supports_sys_prompt = False
|
||||
|
||||
|
||||
if supports_sys_prompt:
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == "system":
|
||||
first_sys_message = messages.pop(i)
|
||||
system_prompt = first_sys_message["content"]
|
||||
break
|
||||
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
|
@ -233,43 +272,58 @@ def completion(
|
|||
input_data = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt,
|
||||
**optional_params
|
||||
**optional_params,
|
||||
}
|
||||
# Otherwise, use the prompt as is
|
||||
else:
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
input_data = {"prompt": prompt, **optional_params}
|
||||
|
||||
## COMPLETION CALL
|
||||
## Replicate Completion calls have 2 steps
|
||||
## Step1: Start Prediction: gets a prediction url
|
||||
## Step2: Poll prediction url for response
|
||||
## Step2: is handled with and without streaming
|
||||
model_response["created"] = int(time.time()) # for pricing this must remain right before calling api
|
||||
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
|
||||
model_response["created"] = int(
|
||||
time.time()
|
||||
) # for pricing this must remain right before calling api
|
||||
prediction_url = start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
print_verbose(prediction_url)
|
||||
|
||||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
print_verbose("streaming request")
|
||||
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
|
||||
return handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
else:
|
||||
result, logs = handle_prediction_response(prediction_url, api_key, print_verbose)
|
||||
model_response["ended"] = time.time() # for pricing this must remain right after calling api
|
||||
result, logs = handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
model_response[
|
||||
"ended"
|
||||
] = time.time() # for pricing this must remain right after calling api
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=result,
|
||||
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=result,
|
||||
additional_args={
|
||||
"complete_input_dict": input_data,
|
||||
"logs": logs,
|
||||
"api_base": prediction_url,
|
||||
},
|
||||
)
|
||||
|
||||
print_verbose(f"raw model_response: {result}")
|
||||
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
|
@ -278,12 +332,14 @@ def completion(
|
|||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", "")))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
|
|
@ -11,42 +11,61 @@ from copy import deepcopy
|
|||
import httpx
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class SagemakerError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class SagemakerConfig():
|
||||
|
||||
class SagemakerConfig:
|
||||
"""
|
||||
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
|
||||
"""
|
||||
max_new_tokens: Optional[int]=None
|
||||
top_p: Optional[float]=None
|
||||
temperature: Optional[float]=None
|
||||
return_full_text: Optional[bool]=None
|
||||
|
||||
def __init__(self,
|
||||
max_new_tokens: Optional[int]=None,
|
||||
top_p: Optional[float]=None,
|
||||
temperature: Optional[float]=None,
|
||||
return_full_text: Optional[bool]=None) -> None:
|
||||
max_new_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
return_full_text: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_full_text: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
SAGEMAKER AUTH Keys/Vars
|
||||
os.environ['AWS_ACCESS_KEY_ID'] = ""
|
||||
|
@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
|
|||
|
||||
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -85,28 +105,30 @@ def completion(
|
|||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume majority of users use .env for auth
|
||||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME") or
|
||||
"us-west-2" # default to us-west-2 if user not specified
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
inference_params.pop("stream", None)
|
||||
|
||||
## Load Config
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
model = model
|
||||
|
@ -114,25 +136,26 @@ def completion(
|
|||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", None),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
messages=messages
|
||||
role_dict=model_prompt_details.get("roles", None),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
if hf_model_name is None:
|
||||
if "llama-2" in model.lower(): # llama-2 model
|
||||
if "chat" in model.lower(): # apply llama2 chat template
|
||||
if "llama-2" in model.lower(): # llama-2 model
|
||||
if "chat" in model.lower(): # apply llama2 chat template
|
||||
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
|
||||
else: # apply regular llama2 template
|
||||
else: # apply regular llama2 template
|
||||
hf_model_name = "meta-llama/Llama-2-7b"
|
||||
hf_model_name = hf_model_name or model # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
|
||||
hf_model_name = (
|
||||
hf_model_name or model
|
||||
) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
|
||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||
|
||||
data = json.dumps({
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params
|
||||
}).encode('utf-8')
|
||||
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
|
@ -142,31 +165,35 @@ def completion(
|
|||
Body={data},
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
""" # type: ignore
|
||||
""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"request_str": request_str,
|
||||
"hf_model_name": hf_model_name,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
try:
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||
|
||||
|
||||
response = response["Body"].read().decode("utf8")
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = json.loads(response)
|
||||
|
@ -177,19 +204,20 @@ def completion(
|
|||
completion_output += completion_response_choices["generation"]
|
||||
elif "generated_text" in completion_response_choices:
|
||||
completion_output += completion_response_choices["generated_text"]
|
||||
|
||||
# check if the prompt template is part of output, if so - filter it out
|
||||
|
||||
# check if the prompt template is part of output, if so - filter it out
|
||||
if completion_output.startswith(prompt) and "<s>" in prompt:
|
||||
completion_output = completion_output.replace(prompt, "", 1)
|
||||
|
||||
model_response["choices"][0]["message"]["content"] = completion_output
|
||||
except:
|
||||
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
|
||||
raise SagemakerError(
|
||||
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
@ -197,28 +225,32 @@ def completion(
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
|
||||
def embedding(model: str,
|
||||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None):
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
"""
|
||||
Supports Huggingface Jumpstart embeddings like GPT-6B
|
||||
Supports Huggingface Jumpstart embeddings like GPT-6B
|
||||
"""
|
||||
### BOTO3 INIT
|
||||
import boto3
|
||||
import boto3
|
||||
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
|
@ -234,34 +266,34 @@ def embedding(model: str,
|
|||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume majority of users use .env for auth
|
||||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME") or
|
||||
"us-west-2" # default to us-west-2 if user not specified
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or "us-west-2" # default to us-west-2 if user not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
inference_params.pop("stream", None)
|
||||
|
||||
## Load Config
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
#### HF EMBEDDING LOGIC
|
||||
data = json.dumps({
|
||||
"text_inputs": input
|
||||
}).encode('utf-8')
|
||||
#### HF EMBEDDING LOGIC
|
||||
data = json.dumps({"text_inputs": input}).encode("utf-8")
|
||||
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
|
@ -270,67 +302,65 @@ def embedding(model: str,
|
|||
ContentType="application/json",
|
||||
Body={data},
|
||||
CustomAttributes="accept_eula=true",
|
||||
)""" # type: ignore
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
)
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data, "request_str": request_str},
|
||||
)
|
||||
## EMBEDDING CALL
|
||||
try:
|
||||
try:
|
||||
response = client.invoke_endpoint(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
response = json.loads(response["Body"].read().decode("utf8"))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
input=input,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
if "embedding" not in response:
|
||||
if "embedding" not in response:
|
||||
raise SagemakerError(status_code=500, message="embedding not found in response")
|
||||
embeddings = response['embedding']
|
||||
embeddings = response["embedding"]
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
|
||||
|
||||
raise SagemakerError(
|
||||
status_code=422, message=f"Response not in expected format - {embeddings}"
|
||||
)
|
||||
|
||||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding
|
||||
}
|
||||
{"object": "embedding", "index": idx, "embedding": embedding}
|
||||
)
|
||||
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = output_data
|
||||
model_response["model"] = model
|
||||
|
||||
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens+=len(encoding.encode(text))
|
||||
input_tokens += len(encoding.encode(text))
|
||||
|
||||
model_response["usage"] = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
|
||||
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -9,17 +9,21 @@ import httpx
|
|||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
class TogetherAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.together.xyz/inference"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class TogetherAIConfig():
|
||||
|
||||
class TogetherAIConfig:
|
||||
"""
|
||||
Reference: https://docs.together.ai/reference/inference
|
||||
|
||||
|
@ -37,35 +41,49 @@ class TogetherAIConfig():
|
|||
|
||||
- `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition.
|
||||
|
||||
- `logprobs` (int32, optional): This parameter is not described in the prompt.
|
||||
- `logprobs` (int32, optional): This parameter is not described in the prompt.
|
||||
"""
|
||||
max_tokens: Optional[int]=None
|
||||
stop: Optional[str]=None
|
||||
temperature:Optional[int]=None
|
||||
top_p: Optional[float]=None
|
||||
top_k: Optional[int]=None
|
||||
repetition_penalty: Optional[float]=None
|
||||
logprobs: Optional[int]=None
|
||||
|
||||
def __init__(self,
|
||||
max_tokens: Optional[int]=None,
|
||||
stop: Optional[str]=None,
|
||||
temperature:Optional[int]=None,
|
||||
top_p: Optional[float]=None,
|
||||
top_k: Optional[int]=None,
|
||||
repetition_penalty: Optional[float]=None,
|
||||
logprobs: Optional[int]=None) -> None:
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[str] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
logprobs: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop: Optional[str] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != 'self' and value is not None:
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {k: v for k, v in cls.__dict__.items()
|
||||
if not k.startswith('__')
|
||||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
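
    # Usage sketch (hypothetical values): set provider-wide defaults once, e.g.
    # litellm.TogetherAIConfig(max_tokens=256, repetition_penalty=1.1); get_config()
    # then returns {"max_tokens": 256, "repetition_penalty": 1.1}, and completion()
    # merges these into optional_params only when the caller hasn't set the same key.
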
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
|
@ -80,10 +98,11 @@ def validate_environment(api_key):
|
|||
}
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
|
@ -97,9 +116,11 @@ def completion(
|
|||
headers = validate_environment(api_key)
|
||||
|
||||
## Load Config
|
||||
config = litellm.TogetherAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:  # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v

print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")

@ -107,15 +128,20 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai")  # api key required to query together ai model list

data = {
"model": model,

@ -128,13 +154,14 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base},
)
## COMPLETION CALL
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
response = requests.post(
api_base,
headers=headers,

@ -143,18 +170,14 @@ def completion(
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:

@ -170,30 +193,38 @@ def completion(
)
elif "error" in completion_response["output"]:
raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code
)

if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]

## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response


def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
import httpx


class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
)  # Call the base class constructor with the parameters it needs


class VertexAIConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts

@ -34,28 +38,42 @@ class VertexAIConfig():
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None

def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}


def _get_image_bytes_from_url(image_url: str) -> bytes:
try:

@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b""  # Return an empty bytes object or handle the error as needed


def _load_image_from_url(image_url: str):

@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str):
Returns:
Image: The loaded image.
"""
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image

image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)


def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.

@ -95,7 +118,7 @@ def _gemini_vision_convert_messages(
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).

Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function.

@ -115,11 +138,23 @@ def _gemini_vision_convert_messages(
try:
import vertexai
except:
raise VertexAIError(status_code=400, message="vertexai import failed please run `pip install google-cloud-aiplatform`")
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image

# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb

@ -159,6 +194,7 @@ def _gemini_vision_convert_messages(
except Exception as e:
raise e


def completion(
model: str,
messages: list,

@ -171,30 +207,38 @@ def completion(
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
):
try:
import vertexai
except:
raise VertexAIError(status_code=400, message="vertexai import failed please run `pip install google-cloud-aiplatform`")
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types

vertexai.init(project=vertex_project, location=vertex_location)

## Load Config
config = litellm.VertexAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v

## Process safety settings into format expected by vertex AI
safety_settings = None
if "safety_settings" in optional_params:
safety_settings = optional_params.pop("safety_settings")

@ -202,17 +246,25 @@ def completion(
raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts")
safety_settings = [gapic_content_types.SafetySetting(x) for x in safety_settings]

# vertexai does not use an API key, it looks for credentials.json in the environment

prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])

mode = ""

request_str = ""
response_obj = None
if model in litellm.vertex_language_models:
llm_model = GenerativeModel(model)
mode = ""
request_str += f"llm_model = GenerativeModel({model})\n"

@ -232,31 +284,76 @@ def completion(
llm_model = CodeGenerationModel.from_pretrained(model)
mode = "text"
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
else:  # vertex_code_llm_models
llm_model = CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"

if acompletion == True:  # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True:
# async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages, print_verbose=print_verbose, **optional_params)

if mode == "":
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":

@ -268,21 +365,35 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})

model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True,
)
optional_params["stream"] = True
return model_response

request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})

## LLM Call
response = llm_model.generate_content(
contents=content,

@ -293,88 +404,150 @@ def completion(
response_obj = response._raw_response
elif mode == "chat":
chat = llm_model.start_chat()
request_str += f"chat = llm_model.start_chat()\n"

if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None)  # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response

request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None)  # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response

request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
completion_response = llm_model.predict(prompt, **optional_params).text

## LOGGING
logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response
)

## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))


async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
"""
Add support for acompletion calls for gemini-pro
"""
try:
from vertexai.preview.generative_models import GenerationConfig

if mode == "":
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":

@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})

## LLM Call
response = await llm_model._generate_content_async(
contents=content, generation_config=GenerationConfig(**optional_params)
)
completion_response = response.text
response_obj = response._raw_response

@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# chat-bison etc.
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text
elif mode == "text":
# gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text

@ -416,51 +609,77 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
)

## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))


async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig

if mode == "":
# gemini-pro
chat = llm_model.start_chat()
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
optional_params["stream"] = True
elif mode == "vision":
stream = optional_params.pop("stream")

print_verbose("\nMaking VertexAI Gemini Pro Vision Call")

@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})

response = llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True,
)
optional_params["stream"] = True
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop("stream", None)  # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop("stream", None)  # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = llm_model.predict_streaming_async(prompt, **optional_params)

streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk


def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

@ -6,7 +6,10 @@ import time, httpx
from typing import Callable, Any
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt

llm = None


class VLLMError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code

@ -17,17 +20,20 @@ class VLLMError(Exception):
self.message
)  # Call the base class constructor with the parameters it needs


# check if vllm is installed
def validate_environment(model: str):
global llm
try:
from vllm import LLM, SamplingParams  # type: ignore

if llm is None:
llm = LLM(model=model)
return llm, SamplingParams
except Exception as e:
raise VLLMError(status_code=0, message=str(e))


def completion(
model: str,
messages: list,

@ -50,15 +56,14 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)

## LOGGING
logging_obj.pre_call(
input=prompt,

@ -69,9 +74,10 @@ def completion(
if llm:
outputs = llm.generate(prompt, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")

## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
return iter(outputs)

@ -88,24 +94,22 @@ def completion(
model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text

## CALCULATING USAGE
prompt_tokens = len(outputs[0].prompt_token_ids)
completion_tokens = len(outputs[0].outputs[0].token_ids)

model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response


def batch_completions(
model: str, messages: list, optional_params=None, custom_prompt_dict={}
):
"""
Example usage:

@ -137,31 +141,33 @@ def batch_completions(
except Exception as e:
error_str = str(e)
if "data parallel group is already initialized" in error_str:
pass
else:
raise VLLMError(status_code=0, message=error_str)
sampling_params = SamplingParams(**optional_params)
prompts = []
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
for message in messages:
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message,
)
prompts.append(prompt)
else:
for message in messages:
prompt = prompt_factory(model=model, messages=message)
prompts.append(prompt)

if llm:
outputs = llm.generate(prompts, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")

final_outputs = []
for output in outputs:

@ -170,20 +176,21 @@ def batch_completions(
model_response["choices"][0]["message"]["content"] = output.outputs[0].text

## CALCULATING USAGE
prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids)

model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
final_outputs.append(model_response)
return final_outputs


def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

litellm/main.py: 1733 changed lines
File diff suppressed because it is too large

@ -1 +1 @@
from . import *

@ -1,4 +1,4 @@
def my_custom_rule(input):  # receives the model response
# if len(input) < 5: # trigger fallback if the model response is too short
return False
return True

@ -3,13 +3,15 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json


class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
"""

def json(self, **kwargs):
try:
return self.model_dump()  # noqa
except:
# if using pydantic v1
return self.dict()

@ -34,7 +36,7 @@ class ProxyChatCompletionRequest(LiteLLMBase):
tools: Optional[List[str]] = None
tool_choice: Optional[str] = None
functions: Optional[List[str]] = None  # soon to be deprecated
function_call: Optional[str] = None  # soon to be deprecated

# Optional LiteLLM params
caching: Optional[bool] = None

@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
request_timeout: Optional[int] = None

class Config:
extra = "allow"  # allow params not defined here, these fall in litellm.completion(**kwargs)


class ModelInfoDelete(LiteLLMBase):
id: Optional[str]

@ -57,38 +60,37 @@ class ModelInfoDelete(LiteLLMBase):
class ModelInfo(LiteLLMBase):
id: Optional[str]
mode: Optional[Literal["embedding", "chat", "completion"]]
input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] = 2048  # assume 2048 if not set

# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json
base_model: Optional[
Literal[
"gpt-4-1106-preview",
"gpt-4-32k",
"gpt-4",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo",
"text-embedding-ada-002",
]
]

class Config:
extra = Extra.allow  # Allow extra fields
protected_namespaces = ()

@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("id") is None:
values.update({"id": str(uuid.uuid4())})
if values.get("mode") is None:
values.update({"mode": None})
if values.get("input_cost_per_token") is None:
values.update({"input_cost_per_token": None})
if values.get("output_cost_per_token") is None:
values.update({"output_cost_per_token": None})
if values.get("max_tokens") is None:
values.update({"max_tokens": None})

@ -97,21 +99,21 @@ class ModelInfo(LiteLLMBase):
return values


class ModelParams(LiteLLMBase):
model_name: str
litellm_params: dict
model_info: ModelInfo

class Config:
protected_namespaces = ()

@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
return values


class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h"
models: Optional[list] = []

@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}


class UpdateKeyRequest(LiteLLMBase):
key: str
duration: Optional[str] = None

@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}


class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
"""
Return the row in the db
"""

api_key: Optional[str] = None
models: list = []
aliases: dict = {}

@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
duration: str = "1h"
metadata: dict = {}


class GenerateKeyResponse(LiteLLMBase):
key: str
expires: Optional[datetime]
user_id: str


class _DeleteKeyObject(LiteLLMBase):
key: str


class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]


class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None


class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None


class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
"""

completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault")
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy")
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key")
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.")
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
health_check_interval: int = Field(300, description="background health check interval in seconds")


class ConfigYAML(LiteLLMBase):
"""
Documents all the fields supported by the config.yaml
"""

model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
general_settings: Optional[ConfigGeneralSettings] = None

class Config:
protected_namespaces = ()

@ -1,14 +1,16 @@
from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os

load_dotenv()


async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
if api_key == modified_master_key:
return UserAPIKeyAuth(api_key=api_key)
raise Exception
except:
raise Exception

@ -4,17 +4,19 @@ import sys, os, traceback
sys.path.insert(
0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect


# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement)  # noqa


class MyCustomHandler(CustomLogger):
def __init__(self):

@ -23,36 +25,38 @@ class MyCustomHandler(CustomLogger):
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
try:
print_verbose(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]

# Pretty print_verbose the methods
for method in methods:
print_verbose(f" - {method}")
print_verbose(f"{reset_color_code}")
except:
pass

def log_pre_api_call(self, model, messages, kwargs):
print_verbose(f"Pre-API Call")

def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"Post-API Call")

def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Stream")

def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!")

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj)
assert response_cost > 0.0
return

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print_verbose(f"On Async Failure !")
except Exception as e:

@ -64,4 +68,4 @@ proxy_handler_instance = MyCustomHandler()
# need to set litellm.callbacks = [customHandler] # on the proxy

# litellm.success_callback = [async_on_succes_logger]

@ -12,25 +12,16 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__)


ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]


def _get_random_llm_message():
"""
Get a random message from the LLM.
"""
messages = ["Hey how's it going?", "What's 1 + 1?"]

return [{"role": "user", "content": random.choice(messages)}]


def _clean_litellm_params(litellm_params: dict):

@ -44,34 +35,40 @@ async def _perform_health_check(model_list: list):
"""
Perform a health check for each model in the list.
"""

async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None)
model_params["prompt"] = "test from litellm"
try:
await litellm.aimage_generation(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
return False
return True

async def _check_embedding_model(model_params: dict):
model_params.pop("messages", None)
model_params["input"] = ["test from litellm"]
try:
await litellm.aembedding(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
return False
return True

async def _check_model(model_params: dict):
try:
await litellm.acompletion(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
return False

return True

tasks = []

@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list):
return healthy_endpoints, unhealthy_endpoints


async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
):
"""
Perform a health check on the system.

@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
"""
if not model_list:
if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
else:
return [], []

@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)

return healthy_endpoints, unhealthy_endpoints

@ -6,35 +6,42 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback


class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass

def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement)  # noqa

async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
try:
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = cache.get_cache(cache_key)
if user_row is None:  # value not yet cached
return
max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]

if max_budget is None:
return

if curr_spend is None:
return

# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()

|
|
|
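A compact paraphrase of the budget gate this hook implements (standalone sketch, not the class itself; the cached row shape with "spend" and "max_budget" mirrors the hunk above):

from fastapi import HTTPException

def check_budget(user_row: dict) -> None:
    # user_row is the cached record keyed by f"{user_id}_user_api_key_user_id"
    if user_row is None:
        return  # nothing cached yet, allow the request
    max_budget = user_row.get("max_budget")
    curr_spend = user_row.get("spend")
    if max_budget is None or curr_spend is None:
        return
    if curr_spend >= max_budget:
        raise HTTPException(status_code=429, detail="Max budget limit reached.")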
@ -5,18 +5,25 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
|
||||
class MaxParallelRequestsHandler(CustomLogger):
|
||||
|
||||
class MaxParallelRequestsHandler(CustomLogger):
|
||||
user_api_key_cache = None
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
|
||||
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str,
|
||||
):
|
||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||
api_key = user_api_key_dict.api_key
|
||||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||
|
@ -26,8 +33,8 @@ class MaxParallelRequestsHandler(CustomLogger):
|
|||
|
||||
if max_parallel_requests is None:
|
||||
return
|
||||
|
||||
self.user_api_key_cache = cache # save the api key cache for updating the value
|
||||
|
||||
self.user_api_key_cache = cache # save the api key cache for updating the value
|
||||
|
||||
# CHECK IF REQUEST ALLOWED
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
|
@ -35,56 +42,67 @@ class MaxParallelRequestsHandler(CustomLogger):
|
|||
self.print_verbose(f"current: {current}")
|
||||
if current is None:
|
||||
cache.set_cache(request_count_api_key, 1)
|
||||
elif int(current) < max_parallel_requests:
|
||||
elif int(current) < max_parallel_requests:
|
||||
# Increase count for this token
|
||||
cache.set_cache(request_count_api_key, int(current) + 1)
|
||||
else:
|
||||
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Max parallel request limit reached."
|
||||
)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
try:
|
||||
self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
|
||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||
if user_api_key is None:
|
||||
return
|
||||
|
||||
if self.user_api_key_cache is None:
|
||||
|
||||
if self.user_api_key_cache is None:
|
||||
return
|
||||
|
||||
|
||||
request_count_api_key = f"{user_api_key}_request_count"
|
||||
# check if it has collected an entire stream response
|
||||
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
|
||||
self.print_verbose(
|
||||
f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
|
||||
)
|
||||
if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
current = (
|
||||
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
)
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in success call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
except Exception as e:
|
||||
self.print_verbose(e) # noqa
|
||||
except Exception as e:
|
||||
self.print_verbose(e) # noqa
|
||||
|
||||
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
|
||||
async def async_log_failure_call(
|
||||
self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
|
||||
):
|
||||
try:
|
||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||
api_key = user_api_key_dict.api_key
|
||||
if api_key is None:
|
||||
return
|
||||
|
||||
if self.user_api_key_cache is None:
|
||||
|
||||
if self.user_api_key_cache is None:
|
||||
return
|
||||
|
||||
|
||||
## decrement call count if call failed
|
||||
if (hasattr(original_exception, "status_code")
|
||||
and original_exception.status_code == 429
|
||||
and "Max parallel request limit reached" in str(original_exception)):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
if (
|
||||
hasattr(original_exception, "status_code")
|
||||
and original_exception.status_code == 429
|
||||
and "Max parallel request limit reached" in str(original_exception)
|
||||
):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
current = (
|
||||
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
)
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in failure call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
except Exception as e:
|
||||
self.print_verbose(f"An exception occurred - {str(e)}") # noqa
|
||||
self.print_verbose(f"An exception occurred - {str(e)}") # noqa
|
||||
|
|
|
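The per-key counter logic above is hard to follow across the reflowed hunks; here is a compact paraphrase using a plain dict in place of the proxy's DualCache (illustrative only):

from fastapi import HTTPException

cache: dict = {}  # stands in for the proxy's DualCache

def pre_call(api_key: str, max_parallel_requests: int) -> None:
    key = f"{api_key}_request_count"
    current = cache.get(key)
    if current is None:
        cache[key] = 1
    elif int(current) < max_parallel_requests:
        cache[key] = int(current) + 1
    else:
        raise HTTPException(status_code=429, detail="Max parallel request limit reached.")

def on_finish(api_key: str) -> None:
    # called from both the success and failure handlers shown above
    key = f"{api_key}_request_count"
    cache[key] = (cache.get(key) or 1) - 1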
@ -6,85 +6,202 @@ from datetime import datetime
|
|||
import importlib
|
||||
from dotenv import load_dotenv
|
||||
import operator
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
config_filename = "litellm.secrets"
|
||||
# Using appdirs to determine user-specific config path
|
||||
config_dir = appdirs.user_config_dir("litellm")
|
||||
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
|
||||
user_config_path = os.getenv(
|
||||
"LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
from importlib import resources
|
||||
import shutil
|
||||
|
||||
telemetry = None
|
||||
|
||||
|
||||
def run_ollama_serve():
|
||||
try:
|
||||
command = ['ollama', 'serve']
|
||||
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
command = ["ollama", "serve"]
|
||||
|
||||
with open(os.devnull, "w") as devnull:
|
||||
process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
print(f"""
|
||||
print(
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""") # noqa
|
||||
"""
|
||||
) # noqa
|
||||
|
||||
|
||||
def clone_subfolder(repo_url, subfolder, destination):
|
||||
# Clone the full repo
|
||||
repo_name = repo_url.split('/')[-1]
|
||||
repo_master = os.path.join(destination, "repo_master")
|
||||
subprocess.run(['git', 'clone', repo_url, repo_master])
|
||||
# Clone the full repo
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
repo_master = os.path.join(destination, "repo_master")
|
||||
subprocess.run(["git", "clone", repo_url, repo_master])
|
||||
|
||||
# Move into the subfolder
|
||||
subfolder_path = os.path.join(repo_master, subfolder)
|
||||
# Move into the subfolder
|
||||
subfolder_path = os.path.join(repo_master, subfolder)
|
||||
|
||||
# Copy subfolder to destination
|
||||
for file_name in os.listdir(subfolder_path):
|
||||
source = os.path.join(subfolder_path, file_name)
|
||||
if os.path.isfile(source):
|
||||
shutil.copy(source, destination)
|
||||
else:
|
||||
dest_path = os.path.join(destination, file_name)
|
||||
shutil.copytree(source, dest_path)
|
||||
# Copy subfolder to destination
|
||||
for file_name in os.listdir(subfolder_path):
|
||||
source = os.path.join(subfolder_path, file_name)
|
||||
if os.path.isfile(source):
|
||||
shutil.copy(source, destination)
|
||||
else:
|
||||
dest_path = os.path.join(destination, file_name)
|
||||
shutil.copytree(source, dest_path)
|
||||
|
||||
# Remove cloned repo folder
|
||||
subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")])
|
||||
feature_telemetry(feature="create-proxy")
|
||||
|
||||
# Remove cloned repo folder
|
||||
subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")])
|
||||
feature_telemetry(feature="create-proxy")
|
||||
|
||||
def is_port_in_use(port):
|
||||
import socket
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.')
|
||||
@click.option('--port', default=8000, help='Port to bind the server to.')
|
||||
@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up')
|
||||
@click.option('--api_base', default=None, help='API base URL.')
|
||||
@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.')
|
||||
@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects')
|
||||
@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")')
|
||||
@click.option('--add_key', default=None, help='The model name to pass to litellm expects')
|
||||
@click.option('--headers', default=None, help='headers for the API call')
|
||||
@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config')
|
||||
@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input')
|
||||
@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints')
|
||||
@click.option('--temperature', default=None, type=float, help='Set temperature for the model')
|
||||
@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model')
|
||||
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
|
||||
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
|
||||
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
|
||||
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
|
||||
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
|
||||
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
|
||||
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
|
||||
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
|
||||
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
|
||||
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
|
||||
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
|
||||
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
|
||||
@click.option('--local', is_flag=True, default=False, help='for local debugging')
|
||||
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
|
||||
@click.option("--host", default="0.0.0.0", help="Host for the server to listen on.")
|
||||
@click.option("--port", default=8000, help="Port to bind the server to.")
|
||||
@click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up")
|
||||
@click.option("--api_base", default=None, help="API base URL.")
|
||||
@click.option(
|
||||
"--api_version",
|
||||
default="2023-07-01-preview",
|
||||
help="For azure - pass in the api version.",
|
||||
)
|
||||
@click.option(
|
||||
"--model", "-m", default=None, help="The model name to pass to litellm expects"
|
||||
)
|
||||
@click.option(
|
||||
"--alias",
|
||||
default=None,
|
||||
help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
|
||||
)
|
||||
@click.option(
|
||||
"--add_key", default=None, help="The model name to pass to litellm expects"
|
||||
)
|
||||
@click.option("--headers", default=None, help="headers for the API call")
|
||||
@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
|
||||
@click.option(
|
||||
"--debug", default=False, is_flag=True, type=bool, help="To debug the input"
|
||||
)
|
||||
@click.option(
|
||||
"--use_queue",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=bool,
|
||||
help="To use celery workers for async endpoints",
|
||||
)
|
||||
@click.option(
|
||||
"--temperature", default=None, type=float, help="Set temperature for the model"
|
||||
)
|
||||
@click.option(
|
||||
"--max_tokens", default=None, type=int, help="Set max tokens for the model"
|
||||
)
|
||||
@click.option(
|
||||
"--request_timeout",
|
||||
default=600,
|
||||
type=int,
|
||||
help="Set timeout in seconds for completion calls",
|
||||
)
|
||||
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
|
||||
@click.option(
|
||||
"--add_function_to_prompt",
|
||||
is_flag=True,
|
||||
help="If function passed but unsupported, pass it as prompt",
|
||||
)
|
||||
@click.option(
|
||||
"--config",
|
||||
"-c",
|
||||
default=None,
|
||||
help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
|
||||
)
|
||||
@click.option(
|
||||
"--max_budget",
|
||||
default=None,
|
||||
type=float,
|
||||
help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
|
||||
)
|
||||
@click.option(
|
||||
"--telemetry",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
|
||||
)
|
||||
@click.option(
|
||||
"--version",
|
||||
"-v",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=bool,
|
||||
help="Print LiteLLM version",
|
||||
)
|
||||
@click.option(
|
||||
"--logs",
|
||||
flag_value=False,
|
||||
type=int,
|
||||
help='Gets the "n" most recent logs. By default gets most recent log.',
|
||||
)
|
||||
@click.option(
|
||||
"--health",
|
||||
flag_value=True,
|
||||
help="Make a chat/completions request to all llms in config.yaml",
|
||||
)
|
||||
@click.option(
|
||||
"--test",
|
||||
flag_value=True,
|
||||
help="proxy chat completions url to make a test request to",
|
||||
)
|
||||
@click.option(
|
||||
"--test_async",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Calls async endpoints /queue/requests and /queue/response",
|
||||
)
|
||||
@click.option(
|
||||
"--num_requests",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Number of requests to hit async endpoint with",
|
||||
)
|
||||
@click.option("--local", is_flag=True, default=False, help="for local debugging")
|
||||
def run_server(
|
||||
host,
|
||||
port,
|
||||
api_base,
|
||||
api_version,
|
||||
model,
|
||||
alias,
|
||||
add_key,
|
||||
headers,
|
||||
save,
|
||||
debug,
|
||||
temperature,
|
||||
max_tokens,
|
||||
request_timeout,
|
||||
drop_params,
|
||||
add_function_to_prompt,
|
||||
config,
|
||||
max_budget,
|
||||
telemetry,
|
||||
logs,
|
||||
test,
|
||||
local,
|
||||
num_workers,
|
||||
test_async,
|
||||
num_requests,
|
||||
use_queue,
|
||||
health,
|
||||
version,
|
||||
):
|
||||
global feature_telemetry
|
||||
args = locals()
|
||||
if local:
|
||||
|
@ -92,51 +209,60 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
else:
|
||||
try:
|
||||
from .proxy_server import app, save_worker_config, usage_telemetry
|
||||
except ImportError as e:
|
||||
except ImportError as e:
|
||||
from proxy_server import app, save_worker_config, usage_telemetry
|
||||
feature_telemetry = usage_telemetry
|
||||
if logs is not None:
|
||||
if logs == 0: # default to 1
|
||||
if logs == 0: # default to 1
|
||||
logs = 1
|
||||
try:
|
||||
with open('api_log.json') as f:
|
||||
with open("api_log.json") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# convert keys to datetime objects
|
||||
log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()}
|
||||
# convert keys to datetime objects
|
||||
log_times = {
|
||||
datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()
|
||||
}
|
||||
|
||||
# sort by timestamp
|
||||
sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True)
|
||||
# sort by timestamp
|
||||
sorted_times = sorted(
|
||||
log_times.items(), key=operator.itemgetter(0), reverse=True
|
||||
)
|
||||
|
||||
# get n recent logs
|
||||
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
|
||||
recent_logs = {
|
||||
k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]
|
||||
}
|
||||
|
||||
print(json.dumps(recent_logs, indent=4)) # noqa
|
||||
print(json.dumps(recent_logs, indent=4)) # noqa
|
||||
except:
|
||||
raise Exception("LiteLLM: No logs saved!")
|
||||
return
|
||||
if version == True:
|
||||
pkg_version = importlib.metadata.version("litellm")
|
||||
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
|
||||
click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
|
||||
return
|
||||
if model and "ollama" in model and api_base is None:
|
||||
if model and "ollama" in model and api_base is None:
|
||||
run_ollama_serve()
|
||||
if test_async is True:
|
||||
if test_async is True:
|
||||
import requests, concurrent, time
|
||||
|
||||
api_base = f"http://{host}:{port}"
|
||||
|
||||
def _make_openai_completion():
|
||||
def _make_openai_completion():
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Write a short poem about the moon"}]
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a short poem about the moon"}
|
||||
],
|
||||
}
|
||||
|
||||
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
|
||||
|
||||
response = response.json()
|
||||
|
||||
while True:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
url = response["url"]
|
||||
polling_url = f"{api_base}{url}"
|
||||
polling_response = requests.get(polling_url)
|
||||
|
@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
if status == "finished":
|
||||
llm_response = polling_response["result"]
|
||||
break
|
||||
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
|
||||
print(
|
||||
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
|
||||
) # noqa
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
print("got exception in polling", e)
|
||||
|
@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
futures = []
|
||||
start_time = time.time()
|
||||
# Make concurrent calls
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=concurrent_calls
|
||||
) as executor:
|
||||
for _ in range(concurrent_calls):
|
||||
futures.append(executor.submit(_make_openai_completion))
|
||||
|
||||
|
@ -171,7 +301,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
failed_calls = 0
|
||||
|
||||
for future in futures:
|
||||
if future.done():
|
||||
if future.done():
|
||||
if future.result() is not None:
|
||||
successful_calls += 1
|
||||
else:
|
||||
|
@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
return
|
||||
if health != False:
|
||||
import requests
|
||||
|
||||
print("\nLiteLLM: Health Testing models in config")
|
||||
response = requests.get(url=f"http://{host}:{port}/health")
|
||||
print(json.dumps(response.json(), indent=4))
|
||||
return
|
||||
if test != False:
|
||||
click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy')
|
||||
click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy")
|
||||
import openai
|
||||
if test == True: # flag value set
|
||||
api_base = f"http://{host}:{port}"
|
||||
else:
|
||||
api_base = test
|
||||
client = openai.OpenAI(
|
||||
api_key="My API Key",
|
||||
base_url=api_base
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
], max_tokens=256)
|
||||
click.echo(f'\nLiteLLM: response from proxy {response}')
|
||||
if test == True: # flag value set
|
||||
api_base = f"http://{host}:{port}"
|
||||
else:
|
||||
api_base = test
|
||||
client = openai.OpenAI(api_key="My API Key", base_url=api_base)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem",
|
||||
}
|
||||
],
|
||||
max_tokens=256,
|
||||
)
|
||||
click.echo(f"\nLiteLLM: response from proxy {response}")
|
||||
|
||||
print("\n Making streaming request to proxy")
|
||||
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem",
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
click.echo(f'LiteLLM: streaming response from proxy {chunk}')
|
||||
click.echo(f"LiteLLM: streaming response from proxy {chunk}")
|
||||
print("\n making completion request to proxy")
|
||||
response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem')
|
||||
response = client.completions.create(
|
||||
model="gpt-3.5-turbo", prompt="this is a test request, write a short poem"
|
||||
)
|
||||
print(response)
|
||||
|
||||
return
|
||||
else:
|
||||
if headers:
|
||||
headers = json.loads(headers)
|
||||
save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, use_queue=use_queue)
|
||||
save_worker_config(
|
||||
model=model,
|
||||
alias=alias,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
debug=debug,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
request_timeout=request_timeout,
|
||||
max_budget=max_budget,
|
||||
telemetry=telemetry,
|
||||
drop_params=drop_params,
|
||||
add_function_to_prompt=add_function_to_prompt,
|
||||
headers=headers,
|
||||
save=save,
|
||||
config=config,
|
||||
use_queue=use_queue,
|
||||
)
|
||||
try:
|
||||
import uvicorn
|
||||
except:
|
||||
raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
|
||||
raise ImportError(
|
||||
"Uvicorn needs to be imported. Run - `pip install uvicorn`"
|
||||
)
|
||||
if port == 8000 and is_port_in_use(port):
|
||||
port = random.randint(1024, 49152)
|
||||
uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers)
|
||||
uvicorn.run(
|
||||
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
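Taken together, the Click options above make up the `litellm` proxy CLI. A couple of illustrative invocations built only from the flags documented in this hunk (the port number and model name are arbitrary examples):

litellm --model gpt-3.5-turbo --port 4000 --debug
litellm --config config.yaml
litellm --test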
File diff suppressed because it is too large
|
@ -1,71 +1,77 @@
|
|||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
import json, subprocess
|
||||
import psutil # Import the psutil library
|
||||
import atexit
|
||||
try:
|
||||
### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
|
||||
from celery import Celery
|
||||
import redis
|
||||
except:
|
||||
import sys
|
||||
# from dotenv import load_dotenv
|
||||
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pip",
|
||||
"install",
|
||||
"redis",
|
||||
"celery"
|
||||
]
|
||||
)
|
||||
# load_dotenv()
|
||||
# import json, subprocess
|
||||
# import psutil # Import the psutil library
|
||||
# import atexit
|
||||
|
||||
import time
|
||||
import sys, os
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path - for litellm local dev
|
||||
import litellm
|
||||
# try:
|
||||
# ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
|
||||
# from celery import Celery
|
||||
# import redis
|
||||
# except:
|
||||
# import sys
|
||||
|
||||
# Redis connection setup
|
||||
pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5)
|
||||
redis_client = redis.Redis(connection_pool=pool)
|
||||
# subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
|
||||
|
||||
# Celery setup
|
||||
celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}")
|
||||
celery_app.conf.update(
|
||||
broker_pool_limit = None,
|
||||
broker_transport_options = {'connection_pool': pool},
|
||||
result_backend_transport_options = {'connection_pool': pool},
|
||||
)
|
||||
# import time
|
||||
# import sys, os
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../../..")
|
||||
# ) # Adds the parent directory to the system path - for litellm local dev
|
||||
# import litellm
|
||||
|
||||
# # Redis connection setup
|
||||
# pool = redis.ConnectionPool(
|
||||
# host=os.getenv("REDIS_HOST"),
|
||||
# port=os.getenv("REDIS_PORT"),
|
||||
# password=os.getenv("REDIS_PASSWORD"),
|
||||
# db=0,
|
||||
# max_connections=5,
|
||||
# )
|
||||
# redis_client = redis.Redis(connection_pool=pool)
|
||||
|
||||
# # Celery setup
|
||||
# celery_app = Celery(
|
||||
# "tasks",
|
||||
# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
|
||||
# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
|
||||
# )
|
||||
# celery_app.conf.update(
|
||||
# broker_pool_limit=None,
|
||||
# broker_transport_options={"connection_pool": pool},
|
||||
# result_backend_transport_options={"connection_pool": pool},
|
||||
# )
|
||||
|
||||
|
||||
# Celery task
|
||||
@celery_app.task(name='process_job', max_retries=3)
|
||||
def process_job(*args, **kwargs):
|
||||
try:
|
||||
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
|
||||
response = llm_router.completion(*args, **kwargs) # type: ignore
|
||||
if isinstance(response, litellm.ModelResponse):
|
||||
response = response.model_dump_json()
|
||||
return json.loads(response)
|
||||
return str(response)
|
||||
except Exception as e:
|
||||
raise e
|
||||
# # Celery task
|
||||
# @celery_app.task(name="process_job", max_retries=3)
|
||||
# def process_job(*args, **kwargs):
|
||||
# try:
|
||||
# llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
|
||||
# response = llm_router.completion(*args, **kwargs) # type: ignore
|
||||
# if isinstance(response, litellm.ModelResponse):
|
||||
# response = response.model_dump_json()
|
||||
# return json.loads(response)
|
||||
# return str(response)
|
||||
# except Exception as e:
|
||||
# raise e
|
||||
|
||||
# Ensure Celery workers are terminated when the script exits
|
||||
def cleanup():
|
||||
try:
|
||||
# Get a list of all running processes
|
||||
for process in psutil.process_iter(attrs=['pid', 'name']):
|
||||
# Check if the process is a Celery worker process
|
||||
if process.info['name'] == 'celery':
|
||||
print(f"Terminating Celery worker with PID {process.info['pid']}")
|
||||
# Terminate the Celery worker process
|
||||
psutil.Process(process.info['pid']).terminate()
|
||||
except Exception as e:
|
||||
print(f"Error during cleanup: {e}")
|
||||
|
||||
# Register the cleanup function to run when the script exits
|
||||
atexit.register(cleanup)
|
||||
# # Ensure Celery workers are terminated when the script exits
|
||||
# def cleanup():
|
||||
# try:
|
||||
# # Get a list of all running processes
|
||||
# for process in psutil.process_iter(attrs=["pid", "name"]):
|
||||
# # Check if the process is a Celery worker process
|
||||
# if process.info["name"] == "celery":
|
||||
# print(f"Terminating Celery worker with PID {process.info['pid']}")
|
||||
# # Terminate the Celery worker process
|
||||
# psutil.Process(process.info["pid"]).terminate()
|
||||
# except Exception as e:
|
||||
# print(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
# # Register the cleanup function to run when the script exits
|
||||
# atexit.register(cleanup)
|
||||
|
|
|
@ -1,12 +1,15 @@
import os
from multiprocessing import Process


def run_worker(cwd):
    os.chdir(cwd)
    os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info")
    os.system(
        "celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info"
    )


def start_worker(cwd):
    cwd += "/queue"
    worker_process = Process(target=run_worker, args=(cwd,))
    worker_process.start()
|
|
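A small usage sketch for the worker launcher above, assuming it is imported from the queue module this diff belongs to (the exact module path is an assumption):

import os
from litellm.proxy.queue.celery_worker import start_worker  # assumed import path

start_worker(os.getcwd())  # appends "/queue" and spawns a celery worker in a child process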
@ -1,26 +1,34 @@
|
|||
import sys, os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
# Add the path to the local folder to sys.path
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path - for litellm local dev
|
||||
# import sys, os
|
||||
# from dotenv import load_dotenv
|
||||
|
||||
def start_rq_worker():
|
||||
from rq import Worker, Queue, Connection
|
||||
from redis import Redis
|
||||
# Set up RQ connection
|
||||
redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
||||
print(redis_conn.ping()) # Should print True if connected successfully
|
||||
# Create a worker and add the queue
|
||||
try:
|
||||
queue = Queue(connection=redis_conn)
|
||||
worker = Worker([queue], connection=redis_conn)
|
||||
except Exception as e:
|
||||
print(f"Error setting up worker: {e}")
|
||||
exit()
|
||||
|
||||
with Connection(redis_conn):
|
||||
worker.work()
|
||||
# load_dotenv()
|
||||
# # Add the path to the local folder to sys.path
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../../..")
|
||||
# ) # Adds the parent directory to the system path - for litellm local dev
|
||||
|
||||
start_rq_worker()
|
||||
|
||||
# def start_rq_worker():
|
||||
# from rq import Worker, Queue, Connection
|
||||
# from redis import Redis
|
||||
|
||||
# # Set up RQ connection
|
||||
# redis_conn = Redis(
|
||||
# host=os.getenv("REDIS_HOST"),
|
||||
# port=os.getenv("REDIS_PORT"),
|
||||
# password=os.getenv("REDIS_PASSWORD"),
|
||||
# )
|
||||
# print(redis_conn.ping()) # Should print True if connected successfully
|
||||
# # Create a worker and add the queue
|
||||
# try:
|
||||
# queue = Queue(connection=redis_conn)
|
||||
# worker = Worker([queue], connection=redis_conn)
|
||||
# except Exception as e:
|
||||
# print(f"Error setting up worker: {e}")
|
||||
# exit()
|
||||
|
||||
# with Connection(redis_conn):
|
||||
# worker.work()
|
||||
|
||||
|
||||
# start_rq_worker()
|
||||
|
|
|
@ -4,20 +4,18 @@ import uuid
|
|||
import traceback
|
||||
|
||||
|
||||
litellm_client = AsyncOpenAI(
|
||||
api_key="test",
|
||||
base_url="http://0.0.0.0:8000"
|
||||
)
|
||||
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
|
||||
|
||||
|
||||
async def litellm_completion():
|
||||
# Your existing code for litellm_completion goes here
|
||||
try:
|
||||
response = await litellm_client.chat.completions.create(
|
||||
response = await litellm_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request
|
||||
messages=[
|
||||
{"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180}
|
||||
], # this is about 4k tokens per request
|
||||
)
|
||||
print(response)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
@ -25,7 +23,6 @@ async def litellm_completion():
|
|||
with open("error_log.txt", "a") as error_log:
|
||||
error_log.write(f"Error during completion: {str(e)}\n")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
|
@ -45,6 +42,7 @@ async def main():
|
|||
|
||||
print(n, time.time() - start, len(successful_completions))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Blank out contents of error_log.txt
|
||||
open("error_log.txt", "w").close()
|
||||
|
|
|
@ -4,16 +4,13 @@ import uuid
|
|||
import traceback
|
||||
|
||||
|
||||
litellm_client = AsyncOpenAI(
|
||||
api_key="sk-1234",
|
||||
base_url="http://0.0.0.0:8000"
|
||||
)
|
||||
litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
|
||||
|
||||
|
||||
async def litellm_completion():
|
||||
# Your existing code for litellm_completion goes here
|
||||
try:
|
||||
response = await litellm_client.chat.completions.create(
|
||||
response = await litellm_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
)
|
||||
|
@ -24,7 +21,6 @@ async def litellm_completion():
|
|||
with open("error_log.txt", "a") as error_log:
|
||||
error_log.write(f"Error during completion: {str(e)}\n")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
|
@ -44,6 +40,7 @@ async def main():
|
|||
|
||||
print(n, time.time() - start, len(successful_completions))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Blank out contents of error_log.txt
|
||||
open("error_log.txt", "w").close()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# test time it takes to make 100 concurrent embedding requests to OpenAI
|
||||
# test time it takes to make 100 concurrent embedding requests to OpenAI
|
||||
|
||||
import sys, os
|
||||
import traceback
|
||||
|
@ -14,16 +14,16 @@ import pytest
|
|||
|
||||
|
||||
import litellm
|
||||
litellm.set_verbose=False
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
|
||||
question = "embed this very long text" * 100
|
||||
|
||||
|
||||
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
|
||||
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
@ -35,7 +35,10 @@ def make_openai_completion(question):
|
|||
try:
|
||||
start_time = time.time()
|
||||
import openai
|
||||
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000",
|
||||
|
||||
client = openai.OpenAI(
|
||||
api_key=os.environ["OPENAI_API_KEY"]
|
||||
) # base_url="http://0.0.0.0:8000",
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-ada-002",
|
||||
input=[question],
|
||||
|
@ -58,6 +61,7 @@ def make_openai_completion(question):
|
|||
# )
|
||||
return None
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
# Number of concurrent calls (you can adjust this)
|
||||
concurrent_calls = 500
|
||||
|
|
|
@ -4,24 +4,20 @@ import uuid
|
|||
import traceback
|
||||
|
||||
|
||||
litellm_client = AsyncOpenAI(
|
||||
api_key="test",
|
||||
base_url="http://0.0.0.0:8000"
|
||||
)
|
||||
|
||||
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
|
||||
|
||||
|
||||
async def litellm_completion():
|
||||
# Your existing code for litellm_completion goes here
|
||||
try:
|
||||
print("starting embedding calls")
|
||||
response = await litellm_client.embeddings.create(
|
||||
response = await litellm_client.embeddings.create(
|
||||
model="text-embedding-ada-002",
|
||||
input = [
|
||||
"hello who are you" * 2000,
|
||||
input=[
|
||||
"hello who are you" * 2000,
|
||||
"hello who are you tomorrow 1234" * 1000,
|
||||
"hello who are you tomorrow 1234" * 1000
|
||||
]
|
||||
"hello who are you tomorrow 1234" * 1000,
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
return response
|
||||
|
@ -31,7 +27,6 @@ async def litellm_completion():
|
|||
with open("error_log.txt", "a") as error_log:
|
||||
error_log.write(f"Error during completion: {str(e)}\n")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
|
@ -51,6 +46,7 @@ async def main():
|
|||
|
||||
print(n, time.time() - start, len(successful_completions))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Blank out contents of error_log.txt
|
||||
open("error_log.txt", "w").close()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# test time it takes to make 100 concurrent embedding requests to OpenAI
|
||||
# test time it takes to make 100 concurrent embedding requests to OpenAI
|
||||
|
||||
import sys, os
|
||||
import traceback
|
||||
|
@ -14,16 +14,16 @@ import pytest
|
|||
|
||||
|
||||
import litellm
|
||||
litellm.set_verbose=False
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
|
||||
question = "embed this very long text" * 100
|
||||
|
||||
|
||||
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
|
||||
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
|
||||
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
|
||||
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
@ -35,7 +35,10 @@ def make_openai_completion(question):
|
|||
try:
|
||||
start_time = time.time()
|
||||
import openai
|
||||
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
|
||||
|
||||
client = openai.OpenAI(
|
||||
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
|
||||
) # base_url="http://0.0.0.0:8000",
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-ada-002",
|
||||
input=[question],
|
||||
|
@ -58,6 +61,7 @@ def make_openai_completion(question):
|
|||
# )
|
||||
return None
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
# Number of concurrent calls (you can adjust this)
|
||||
concurrent_calls = 500
|
||||
|
|
|
@ -2,6 +2,7 @@ import requests
|
|||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
@ -12,37 +13,35 @@ base_url = "https://api.litellm.ai"
|
|||
|
||||
# Step 1 Add a config to the proxy, generate a temp key
|
||||
config = {
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.environ['OPENAI_API_KEY'],
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.environ['AZURE_API_KEY'],
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_version": "2023-07-01-preview"
|
||||
}
|
||||
}
|
||||
]
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.environ["AZURE_API_KEY"],
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_version": "2023-07-01-preview",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
print("STARTING LOAD TEST Q")
|
||||
print(os.environ['AZURE_API_KEY'])
|
||||
print(os.environ["AZURE_API_KEY"])
|
||||
|
||||
response = requests.post(
|
||||
url=f"{base_url}/key/generate",
|
||||
json={
|
||||
"config": config,
|
||||
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
|
||||
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
|
||||
},
|
||||
headers={
|
||||
"Authorization": "Bearer sk-hosted-litellm"
|
||||
}
|
||||
headers={"Authorization": "Bearer sk-hosted-litellm"},
|
||||
)
|
||||
|
||||
print("\nresponse from generating key", response.text)
|
||||
|
@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key)
|
|||
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def create_job_and_poll(request_num):
|
||||
print(f"Creating a job on the proxy for request {request_num}")
|
||||
job_response = requests.post(
|
||||
url=f"{base_url}/queue/request",
|
||||
json={
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'messages': [
|
||||
{'role': 'system', 'content': 'write a short poem'},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "system", "content": "write a short poem"},
|
||||
],
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
headers={"Authorization": f"Bearer {generated_key}"},
|
||||
)
|
||||
print(job_response.status_code)
|
||||
print(job_response.text)
|
||||
|
@ -84,12 +82,12 @@ def create_job_and_poll(request_num):
|
|||
try:
|
||||
print(f"\nPolling URL for request {request_num}", polling_url)
|
||||
polling_response = requests.get(
|
||||
url=polling_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
|
||||
)
|
||||
print(
|
||||
f"\nResponse from polling url for request {request_num}",
|
||||
polling_response.text,
|
||||
)
|
||||
print(f"\nResponse from polling url for request {request_num}", polling_response.text)
|
||||
polling_response = polling_response.json()
|
||||
status = polling_response.get("status", None)
|
||||
if status == "finished":
|
||||
|
@ -109,6 +107,7 @@ def create_job_and_poll(request_num):
|
|||
except Exception as e:
|
||||
print("got exception when polling", e)
|
||||
|
||||
|
||||
# Number of requests
|
||||
num_requests = 100
|
||||
|
||||
|
@ -118,4 +117,4 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor
|
|||
futures = [executor.submit(create_job_and_poll, i) for i in range(num_requests)]
|
||||
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
concurrent.futures.wait(futures)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# # This tests the litellm proxy
|
||||
# # This tests the litellm proxy
|
||||
# # it makes async Completion requests with streaming
|
||||
# import openai
|
||||
|
||||
|
@ -8,14 +8,14 @@
|
|||
|
||||
# async def test_async_completion():
|
||||
# response = await (
|
||||
# model="gpt-3.5-turbo",
|
||||
# model="gpt-3.5-turbo",
|
||||
# prompt='this is a test request, write a short poem',
|
||||
# )
|
||||
# print(response)
|
||||
|
||||
# print("test_streaming")
|
||||
# response = await openai.chat.completions.create(
|
||||
# model="gpt-3.5-turbo",
|
||||
# model="gpt-3.5-turbo",
|
||||
# prompt='this is a test request, write a short poem',
|
||||
# stream=True
|
||||
# )
|
||||
|
@ -26,4 +26,3 @@
|
|||
|
||||
# import asyncio
|
||||
# asyncio.run(test_async_completion())
|
||||
|
||||
|
|
|
@ -34,7 +34,3 @@
|
|||
# response = claude_chat(messages)
|
||||
|
||||
# print(response)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import requests
|
|||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
@ -12,26 +13,24 @@ base_url = "https://api.litellm.ai"
|
|||
|
||||
# Step 1 Add a config to the proxy, generate a temp key
|
||||
config = {
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.environ['OPENAI_API_KEY'],
|
||||
}
|
||||
}
|
||||
]
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{base_url}/key/generate",
|
||||
json={
|
||||
"config": config,
|
||||
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
|
||||
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
|
||||
},
|
||||
headers={
|
||||
"Authorization": "Bearer sk-hosted-litellm"
|
||||
}
|
||||
headers={"Authorization": "Bearer sk-hosted-litellm"},
|
||||
)
|
||||
|
||||
print("\nresponse from generating key", response.text)
|
||||
|
@ -45,22 +44,23 @@ print("Creating a job on the proxy")
|
|||
job_response = requests.post(
|
||||
url=f"{base_url}/queue/request",
|
||||
json={
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'messages': [
|
||||
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a helpful assistant. What is your name",
|
||||
},
|
||||
],
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
headers={"Authorization": f"Bearer {generated_key}"},
|
||||
)
|
||||
print(job_response.status_code)
|
||||
print(job_response.text)
|
||||
print("\nResponse from creating job", job_response.text)
|
||||
job_response = job_response.json()
|
||||
job_id = job_response["id"] # type: ignore
|
||||
polling_url = job_response["url"] # type: ignore
|
||||
polling_url = f"{base_url}{polling_url}"
|
||||
job_id = job_response["id"] # type: ignore
|
||||
polling_url = job_response["url"] # type: ignore
|
||||
polling_url = f"{base_url}{polling_url}"
|
||||
print("\nCreated Job, Polling Url", polling_url)
|
||||
|
||||
# Step 3: Poll the request
|
||||
|
@ -68,16 +68,13 @@ while True:
|
|||
try:
|
||||
print("\nPolling URL", polling_url)
|
||||
polling_response = requests.get(
|
||||
url=polling_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
|
||||
)
|
||||
print("\nResponse from polling url", polling_response.text)
|
||||
polling_response = polling_response.json()
|
||||
status = polling_response.get("status", None) # type: ignore
|
||||
status = polling_response.get("status", None) # type: ignore
|
||||
if status == "finished":
|
||||
llm_response = polling_response["result"] # type: ignore
|
||||
llm_response = polling_response["result"] # type: ignore
|
||||
print("LLM Response")
|
||||
print(llm_response)
|
||||
break
|
||||
|
|
|
@ -8,16 +8,19 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
|
||||
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
"""
|
||||
Logging/Custom Handlers for proxy.
|
||||
Logging/Custom Handlers for proxy.
|
||||
|
||||
Implemented mainly to:
|
||||
- log successful/failed db read/writes
|
||||
- log successful/failed db read/writes
|
||||
- support the max parallel request integration
|
||||
"""
|
||||
|
||||
|
@ -25,15 +28,15 @@ class ProxyLogging:
|
|||
## INITIALIZE LITELLM CALLBACKS ##
|
||||
self.call_details: dict = {}
|
||||
self.call_details["user_api_key_cache"] = user_api_key_cache
|
||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||
self.max_budget_limiter = MaxBudgetLimiter()
|
||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||
self.max_budget_limiter = MaxBudgetLimiter()
|
||||
pass
|
||||
|
||||
def _init_litellm_callbacks(self):
|
||||
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter)
|
||||
litellm.callbacks.append(self.max_budget_limiter)
|
||||
for callback in litellm.callbacks:
|
||||
for callback in litellm.callbacks:
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
|
@ -44,7 +47,7 @@ class ProxyLogging:
|
|||
litellm._async_success_callback.append(callback)
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback)
|
||||
|
||||
|
||||
if (
|
||||
len(litellm.input_callback) > 0
|
||||
or len(litellm.success_callback) > 0
|
||||
|
@ -57,31 +60,41 @@ class ProxyLogging:
|
|||
+ litellm.failure_callback
|
||||
)
|
||||
)
|
||||
litellm.utils.set_callbacks(
|
||||
callback_list=callback_list
|
||||
)
|
||||
litellm.utils.set_callbacks(callback_list=callback_list)
|
||||
|
||||
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
async def pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
data: dict,
|
||||
call_type: Literal["completion", "embeddings"],
|
||||
):
|
||||
"""
|
||||
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
2. /embeddings
|
||||
"""
|
||||
try:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
|
||||
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
|
||||
if response is not None:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
|
||||
callback.__class__
|
||||
):
|
||||
response = await callback.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=self.call_details["user_api_key_cache"],
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
)
|
||||
if response is not None:
|
||||
data = response
|
||||
|
||||
print_verbose(f'final data being sent to {call_type} call: {data}')
|
||||
print_verbose(f"final data being sent to {call_type} call: {data}")
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
"""
|
||||
Log successful db read/writes
|
||||
"""
|
||||
|
@ -93,26 +106,31 @@ class ProxyLogging:
|
|||
|
||||
Currently only logs exceptions to sentry
|
||||
"""
|
||||
if litellm.utils.capture_exception:
|
||||
if litellm.utils.capture_exception:
|
||||
litellm.utils.capture_exception(error=original_exception)
|
||||
|
||||
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||
async def post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
"""
|
||||
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
2. /embeddings
|
||||
"""
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
|
||||
except Exception as e:
|
||||
await callback.async_post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=original_exception,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return
|
||||
|
||||
|
||||
|
||||
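ProxyLogging.pre_call_hook above dispatches to any callback that defines async_pre_call_hook; a minimal custom hook following the signature shown in this diff might look like the sketch below (illustrative only; the max_tokens cap is an arbitrary example, and the DualCache import path is an assumption):

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache  # assumed import path

class CapMaxTokens(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        # cap max_tokens on every /chat/completions and /embeddings request
        if data.get("max_tokens", 0) > 1024:
            data["max_tokens"] = 1024
        return data  # a non-None return value replaces the request body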
### DB CONNECTOR ###
|
||||
# Define the retry decorator with backoff strategy
|
||||
|
@@ -121,9 +139,12 @@ def on_backoff(details):
    # The 'tries' key in the details dictionary contains the number of completed tries
    print_verbose(f"Backing off... this was attempt #{details['tries']}")


class PrismaClient:
    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
        print_verbose(
            "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
        )
        ## init logging object
        self.proxy_logging_obj = proxy_logging_obj

@@ -136,23 +157,24 @@ class PrismaClient:
        os.chdir(dname)

        try:
            subprocess.run(["prisma", "generate"])
            subprocess.run(
                ["prisma", "db", "push", "--accept-data-loss"]
            )  # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
        finally:
            os.chdir(original_dir)
        # Now you can import the Prisma Client
        from prisma import Client  # type: ignore

        self.db = Client()  # Client to connect to Prisma db

    def hash_token(self, token: str):
        # Hash the string using SHA-256
        hashed_token = hashlib.sha256(token.encode()).hexdigest()

        return hashed_token

    def jsonify_object(self, data: dict) -> dict:
        db_data = copy.deepcopy(data)

        for k, v in db_data.items():

@@ -162,233 +184,258 @@ class PrismaClient:

    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def get_data(
        self,
        token: Optional[str] = None,
        expires: Optional[Any] = None,
        user_id: Optional[str] = None,
    ):
        try:
            response = None
            if token is not None:
                # check if plain text or hash
                hashed_token = token
                if token.startswith("sk-"):
                    hashed_token = self.hash_token(token=token)
                response = await self.db.litellm_verificationtoken.find_unique(
                    where={"token": hashed_token}
                )
                if response:
                    # Token exists, now check expiration.
                    if response.expires is not None and expires is not None:
                        if response.expires >= expires:
                            # Token exists and is not expired.
                            return response
                        else:
                            # Token exists but is expired.
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail="expired user key",
                            )
                    return response
                else:
                    # Token does not exist.
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="invalid user key",
                    )
            elif user_id is not None:
                response = await self.db.litellm_usertable.find_unique(  # type: ignore
                    where={
                        "user_id": user_id,
                    }
                )
            return response
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def insert_data(self, data: dict):
        """
        Add a key to the database. If it already exists, do nothing.
        """
        try:
            token = data["token"]
            hashed_token = self.hash_token(token=token)
            db_data = self.jsonify_object(data=data)
            db_data["token"] = hashed_token
            max_budget = db_data.pop("max_budget", None)
            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
                where={
                    "token": hashed_token,
                },
                data={
                    "create": {**db_data},  # type: ignore
                    "update": {},  # don't do anything if it already exists
                },
            )

            new_user_row = await self.db.litellm_usertable.upsert(
                where={"user_id": data["user_id"]},
                data={
                    "create": {"user_id": data["user_id"], "max_budget": max_budget},
                    "update": {},  # don't do anything if it already exists
                },
            )
            return new_verification_token
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def update_data(
        self,
        token: Optional[str] = None,
        data: dict = {},
        user_id: Optional[str] = None,
    ):
        """
        Update existing data
        """
        try:
            db_data = self.jsonify_object(data=data)
            if token is not None:
                print_verbose(f"token: {token}")
                # check if plain text or hash
                if token.startswith("sk-"):
                    token = self.hash_token(token=token)
                db_data["token"] = token
                response = await self.db.litellm_verificationtoken.update(
                    where={"token": token},  # type: ignore
                    data={**db_data},  # type: ignore
                )
                print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
                return {"token": token, "data": db_data}
            elif user_id is not None:
                """
                If data['spend'] + data['user'], update the user table with spend info as well
                """
                update_user_row = await self.db.litellm_usertable.update(
                    where={"user_id": user_id},  # type: ignore
                    data={**db_data},  # type: ignore
                )
                return {"user_id": user_id, "data": db_data}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def delete_data(self, tokens: List):
        """
        Allow user to delete a key(s)
        """
        try:
            hashed_tokens = [self.hash_token(token=token) for token in tokens]
            await self.db.litellm_verificationtoken.delete_many(
                where={"token": {"in": hashed_tokens}}
            )
            return {"deleted_keys": tokens}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def connect(self):
        try:
            await self.db.connect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def disconnect(self):
        try:
            await self.db.disconnect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e


### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    try:
        print_verbose(f"value: {value}")
        # Split the path by dots to separate module from instance
        parts = value.split(".")

        # The module path is all but the last part, and the instance_name is the last part
        module_name = ".".join(parts[:-1])
        instance_name = parts[-1]

        # If config_file_path is provided, use it to determine the module spec and load the module
        if config_file_path is not None:
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split("."))
            module_file_path += ".py"

            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            if spec is None:
                raise ImportError(
                    f"Could not find a module specification for {module_file_path}"
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)  # type: ignore
        else:
            # Dynamically import the module
            module = importlib.import_module(module_name)

        # Get the instance from the module
        instance = getattr(module, instance_name)

        return instance
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
    except Exception as e:
        raise e


### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
    """
    Check if a user_id exists in cache,
    if not retrieve it.
    """
    cache_key = f"{user_id}_user_api_key_user_id"
    response = cache.get_cache(key=cache_key)
    if response is None:  # Cache miss
        user_row = await db.get_data(user_id=user_id)
        cache_value = user_row.model_dump_json()
        cache.set_cache(
            key=cache_key, value=cache_value, ttl=600
        )  # store for 10 minutes
    return
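A rough usage sketch for the client above. It is not part of the diff: it assumes DATABASE_URL points at a Prisma-managed database and that ProxyLogging is constructed the way the proxy server constructs it (with the proxy's user-API-key DualCache).

import asyncio, os

from litellm.caching import DualCache
from litellm.proxy.utils import PrismaClient, ProxyLogging


async def main():
    # assumed constructor shape for ProxyLogging
    proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
    db = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )
    await db.connect()
    # insert_data() hashes the raw key via hash_token() before writing it
    await db.insert_data(data={"token": "sk-1234", "user_id": "user-1"})
    print(await db.get_data(token="sk-1234"))
    await db.disconnect()


asyncio.run(main())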
litellm/router.py — 1241 lines changed. File diff suppressed because it is too large.
@@ -1,96 +1,121 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic

import dotenv, os, requests
from typing import Optional

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger


class LeastBusyLoggingHandler(CustomLogger):
    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache
        self.mapping_deployment_to_id: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        """
        Log when a model is being used.

        Caching based on model group.
        """
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if deployment is None or model_group is None or id is None:
                    return

                # map deployment to id
                self.mapping_deployment_to_id[deployment] = id

                request_count_api_key = f"{model_group}_request_count"
                # update cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = (
                    request_count_dict.get(deployment, 0) + 1
                )
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                if deployment is None or model_group is None:
                    return

                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                if deployment is None or model_group is None:
                    return

                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    def get_available_deployments(self, model_group: str):
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        # map deployment to id
        return_dict = {}
        for key, value in request_count_dict.items():
            return_dict[self.mapping_deployment_to_id[key]] = value
        return return_dict
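A sketch of how this handler is meant to be wired up, following the comments at the top of the file. The import path and the callback registration below are assumptions for illustration; they are not shown in this diff.

import litellm
from litellm.caching import DualCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler  # assumed path

router_cache = DualCache()
handler = LeastBusyLoggingHandler(router_cache=router_cache)

# log_pre_api_call bumps the in-flight count; the success/failure hooks settle it
litellm.input_callback = [handler]
litellm.success_callback = [handler]
litellm.failure_callback = [handler]

# at routing time, read the per-deployment traffic for a model group
print(handler.get_available_deployments(model_group="gpt-3.5-turbo"))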
@@ -2,6 +2,7 @@

import pytest, sys, os
import importlib

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

@@ -11,24 +12,30 @@ import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
    """
    curr_dir = os.getcwd()  # Get the current working directory
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path
    import litellm

    importlib.reload(litellm)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield


def pytest_collection_modifyitems(config, items):
    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
    custom_logger_tests = [
        item for item in items if "custom_logger" in item.parent.name
    ]
    other_tests = [item for item in items if "custom_logger" not in item.parent.name]

    # Sort tests based on their names
    custom_logger_tests.sort(key=lambda x: x.name)
    other_tests.sort(key=lambda x: x.name)

    # Reorder the items list
    items[:] = custom_logger_tests + other_tests
@@ -4,6 +4,7 @@
import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

@@ -11,53 +12,62 @@ import litellm
from litellm import Router
import concurrent
from dotenv import load_dotenv

load_dotenv()

model_list = [
    {  # list of model deployments
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": "bad-key",
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "tpm": 1000000,
        "rpm": 9000,
    },
]

kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
}


def test_multiple_deployments_sync():
    import concurrent, time

    litellm.set_verbose = False
    results = []
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT")),  # type: ignore
        routing_strategy="simple-shuffle",
        set_verbose=True,
        num_retries=1,
    )  # type: ignore
    try:
        for _ in range(3):
            response = router.completion(**kwargs)
            results.append(response)
        print(results)
        router.reset()
    except Exception as e:
        print(f"FAILED TEST!")
        pytest.fail(f"An error occurred - {traceback.format_exc()}")


# test_multiple_deployments_sync()

@@ -67,13 +77,15 @@ def test_multiple_deployments_parallel():
    results = []
    futures = {}
    start_time = time.time()
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT")),  # type: ignore
        routing_strategy="simple-shuffle",
        set_verbose=True,
        num_retries=1,
    )  # type: ignore
    # Assuming you have an executor instance defined somewhere in your code
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for _ in range(5):

@@ -82,7 +94,11 @@ def test_multiple_deployments_parallel():

    # Retrieve the results from the futures
    while futures:
        done, not_done = concurrent.futures.wait(
            futures.values(),
            timeout=10,
            return_when=concurrent.futures.FIRST_COMPLETED,
        )
        for future in done:
            try:
                result = future.result()

@@ -98,12 +114,14 @@ def test_multiple_deployments_parallel():
    print(results)
    print(f"ELAPSED TIME: {end_time - start_time}")


# Assuming litellm, router, and executor are defined somewhere in your code


# test_multiple_deployments_parallel()
def test_cooldown_same_model_name():
    # users could have the same model with different api_base
    # example
    # azure/chatgpt, api_base: 1234
    # azure/chatgpt, api_base: 1235
    # if 1234 fails, it should only cooldown 1234 and then try with 1235

@@ -118,7 +136,7 @@ def test_cooldown_same_model_name():
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": "BAD_API_BASE",
                    "tpm": 90,
                },
            },
            {

@@ -128,7 +146,7 @@ def test_cooldown_same_model_name():
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "tpm": 0.000001,
                },
            },
        ]

@@ -140,17 +158,12 @@ def test_cooldown_same_model_name():
            redis_port=int(os.getenv("REDIS_PORT")),
            routing_strategy="simple-shuffle",
            set_verbose=True,
            num_retries=3,
        )  # type: ignore

        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello this request will pass"}],
        )
        print(router.model_list)
        model_ids = []

@@ -159,10 +172,13 @@ def test_cooldown_same_model_name():
        print("\n litellm model ids ", model_ids)

        # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
        assert (
            model_ids[0] != model_ids[1]
        )  # ensure both models have a uuid added, and they have different names

        print("\ngot response\n", response)
    except Exception as e:
        pytest.fail(f"Got unexpected exception on router! - {e}")


test_cooldown_same_model_name()
@@ -9,68 +9,70 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import litellm


## case 1: set_function_to_prompt not set
def test_function_call_non_openai_model():
    try:
        model = "claude-instant-1"
        messages = [{"role": "user", "content": "what's the weather in sf?"}]
        functions = [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ]
        response = litellm.completion(
            model=model, messages=messages, functions=functions
        )
        pytest.fail(f"An error occurred")
    except Exception as e:
        print(e)
        pass


test_function_call_non_openai_model()


## case 2: add_function_to_prompt set
def test_function_call_non_openai_model_litellm_mod_set():
    litellm.add_function_to_prompt = True
    try:
        model = "claude-instant-1"
        messages = [{"role": "user", "content": "what's the weather in sf?"}]
        functions = [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ]
        response = litellm.completion(
            model=model, messages=messages, functions=functions
        )
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"An error occurred {e}")


# test_function_call_non_openai_model_litellm_mod_set()
@@ -1,4 +1,3 @@
import sys, os
import traceback
from dotenv import load_dotenv

@@ -8,7 +7,7 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout, acompletion

@@ -20,18 +19,18 @@ import tempfile
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

@@ -55,13 +54,13 @@ def load_vertex_ai_credentials():
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


@pytest.mark.asyncio
async def get_response():

@@ -89,43 +88,80 @@ def test_vertex_ai():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    litellm.set_verbose = False
    litellm.vertex_project = "reliablekeys"

    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
            if model in [
                "code-gecko",
                "code-gecko@001",
                "code-gecko@002",
                "code-gecko@latest",
                "code-bison@001",
                "text-bison@001",
            ]:
                # our account does not have access to this model
                continue
            print("making request", model)
            response = completion(
                model=model,
                messages=[{"role": "user", "content": "hi"}],
                temperature=0.7,
            )
            print("\nModel Response", response)
            print(response)
            assert type(response.choices[0].message.content) == str
            assert len(response.choices[0].message.content) > 1
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")


# test_vertex_ai()


def test_vertex_ai_stream():
    load_vertex_ai_credentials()
    litellm.set_verbose = False
    litellm.vertex_project = "reliablekeys"
    import random

    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
            if model in [
                "code-gecko",
                "code-gecko@001",
                "code-gecko@002",
                "code-gecko@latest",
                "code-bison@001",
                "text-bison@001",
            ]:
                # our account does not have access to this model
                continue
            print("making request", model)
            response = completion(
                model=model,
                messages=[
                    {"role": "user", "content": "write 10 line code code for saying hi"}
                ],
                stream=True,
            )
            completed_str = ""
            for chunk in response:
                print(chunk)

@@ -137,47 +173,86 @@ def test_vertex_ai_stream():
            assert len(completed_str) > 4
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")


# test_vertex_ai_stream()


@pytest.mark.asyncio
async def test_async_vertexai_response():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        print(f"model being tested in async call: {model}")
        if model in [
            "code-gecko",
            "code-gecko@001",
            "code-gecko@002",
            "code-gecko@latest",
            "code-bison@001",
            "text-bison@001",
        ]:
            # our account does not have access to this model
            continue
        try:
            user_message = "Hello, how are you?"
            messages = [{"content": user_message, "role": "user"}]
            response = await acompletion(
                model=model, messages=messages, temperature=0.7, timeout=5
            )
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")


# asyncio.run(test_async_vertexai_response())


@pytest.mark.asyncio
async def test_async_vertexai_streaming_response():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        if model in [
            "code-gecko",
            "code-gecko@001",
            "code-gecko@002",
            "code-gecko@latest",
            "code-bison@001",
            "text-bison@001",
        ]:
            # our account does not have access to this model
            continue
        try:
            user_message = "Hello, how are you?"
            messages = [{"content": user_message, "role": "user"}]
            response = await acompletion(
                model="gemini-pro",
                messages=messages,
                temperature=0.7,
                timeout=5,
                stream=True,
            )
            print(f"response: {response}")
            complete_response = ""
            async for chunk in response:

@@ -185,44 +260,46 @@ async def test_async_vertexai_streaming_response():
                complete_response += chunk.choices[0].delta.content
            print(f"complete_response: {complete_response}")
            assert len(complete_response) > 0
        except litellm.Timeout as e:
            pass
        except Exception as e:
            print(e)
            pytest.fail(f"An exception occurred: {e}")


# asyncio.run(test_async_vertexai_streaming_response())


def test_gemini_pro_vision():
    try:
        load_vertex_ai_credentials()
        litellm.set_verbose = True
        litellm.num_retries = 0
        resp = litellm.completion(
            model="vertex_ai/gemini-pro-vision",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Whats in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
                            },
                        },
                    ],
                }
            ],
        )
        print(resp)
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise e


# test_gemini_pro_vision()

@@ -333,4 +410,4 @@ def test_gemini_pro_vision():
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()
@@ -11,8 +11,10 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion, acreate

litellm.num_retries = 3


def test_sync_response():
    litellm.set_verbose = False
    user_message = "Hello, how are you?"

@@ -20,35 +22,49 @@ def test_sync_response():
    try:
        response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
        print(f"response: {response}")
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


# test_sync_response()


def test_sync_response_anyscale():
    litellm.set_verbose = False
    user_message = "Hello, how are you?"
    messages = [{"content": user_message, "role": "user"}]
    try:
        response = completion(
            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
            messages=messages,
            timeout=5,
        )
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


# test_sync_response_anyscale()


def test_async_response_openai():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="gpt-3.5-turbo", messages=messages, timeout=5
            )
            print(f"response: {response}")
            print(f"response ms: {response._response_ms}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

@@ -56,54 +72,75 @@ def test_async_response_openai():

    asyncio.run(test_get_response())


# test_async_response_openai()


def test_async_response_azure():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "What do you know?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="azure/gpt-turbo",
                messages=messages,
                base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
                api_key=os.getenv("AZURE_FRANCE_API_KEY"),
            )
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_get_response())


# test_async_response_azure()


def test_async_anyscale_response():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
                messages=messages,
                timeout=5,
            )
            # response = await response
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_get_response())


# test_async_anyscale_response()


def test_get_response_streaming():
    import asyncio

    async def test_async_call():
        user_message = "write a short poem in one sentence"
        messages = [{"content": user_message, "role": "user"}]
        try:
            litellm.set_verbose = True
            response = await acompletion(
                model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5
            )
            print(type(response))

            import inspect

@@ -116,29 +153,39 @@ def test_get_response_streaming():
            async for chunk in response:
                token = chunk["choices"][0]["delta"].get("content", "")
                if token == None:
                    continue  # openai v1.0.0 returns content=None
                output += token
            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
            print(f"output: {output}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_async_call())


# test_get_response_streaming()


def test_get_response_non_openai_streaming():
    import asyncio

    litellm.set_verbose = True
    litellm.num_retries = 0

    async def test_async_call():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
                messages=messages,
                stream=True,
                timeout=5,
            )
            print(type(response))

            import inspect

@@ -158,11 +205,13 @@ def test_get_response_non_openai_streaming():
            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")
        return response

    asyncio.run(test_async_call())


# test_get_response_non_openai_streaming()
@ -3,79 +3,105 @@
|
|||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import openai, litellm, uuid
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
client = AsyncAzureOpenAI(
|
||||
api_key=os.getenv("AZURE_API_KEY"),
|
||||
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
api_version=os.getenv("AZURE_API_VERSION")
|
||||
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
api_version=os.getenv("AZURE_API_VERSION"),
|
||||
)
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-test",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION")
|
||||
{
|
||||
"model_name": "azure-test",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
router = litellm.Router(model_list=model_list)
|
||||
|
||||
|
||||
async def _openai_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await client.chat.completions.create(
|
||||
model="chatgpt-v-2",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print("OpenAI Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await client.chat.completions.create(
|
||||
model="chatgpt-v-2",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"OpenAI Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
async def _router_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await router.acompletion(
|
||||
model="azure-test",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print("Router Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time - first_token_ts)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await router.acompletion(
|
||||
model="azure-test",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"Router Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time - first_token_ts,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
async def test_azure_completion_streaming():
|
||||
|
||||
async def test_azure_completion_streaming():
|
||||
"""
|
||||
Test azure streaming call - measure on time to first (non-null) token.
|
||||
Test azure streaming call - measure on time to first (non-null) token.
|
||||
"""
|
||||
n = 3 # Number of concurrent tasks
|
||||
## OPENAI AVG. TIME
|
||||
|
@ -83,19 +109,20 @@ async def test_azure_completion_streaming():
|
|||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_openai_time = total_time/3
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_openai_time = total_time / 3
|
||||
## ROUTER AVG. TIME
|
||||
tasks = [_router_completion() for _ in range(n)]
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_router_time = total_time/3
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_router_time = total_time / 3
|
||||
## COMPARE
|
||||
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
|
||||
assert avg_router_time < avg_openai_time + 0.5
|
||||
|
||||
# asyncio.run(test_azure_completion_streaming())
|
||||
|
||||
# asyncio.run(test_azure_completion_streaming())
|
||||
|
|
|
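The streaming benchmarks above time how long a call takes to emit its first non-empty token. A minimal sketch of the same measurement on its own, assuming litellm's async API with an OpenAI-compatible model and credentials already set in the environment (the helper name is illustrative):

import asyncio
import time
import uuid
from typing import Optional

import litellm


async def time_to_first_token(model: str = "gpt-3.5-turbo") -> Optional[float]:
    """Return seconds until the first non-empty streamed token, or None on error."""
    try:
        start = time.time()
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
            stream=True,
        )
        async for chunk in response:
            # skip role-only / empty chunks until real content arrives
            if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                return time.time() - start
        return None
    except Exception as e:
        print(e)
        return None


# asyncio.run(time_to_first_token())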
@ -5,6 +5,7 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
@ -18,6 +19,7 @@ user_message = "Hello, how are you?"
|
|||
messages = [{"content": user_message, "role": "user"}]
|
||||
model_val = None
|
||||
|
||||
|
||||
def test_completion_with_no_model():
|
||||
# test on empty
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -32,9 +34,10 @@ def test_completion_with_empty_model():
|
|||
print(f"error occurred: {e}")
|
||||
pass
|
||||
|
||||
|
||||
# def test_completion_catch_nlp_exception():
|
||||
# TEMP commented out NLP cloud API is unstable
|
||||
# try:
|
||||
# try:
|
||||
# response = completion(model="dolphin", messages=messages, functions=[
|
||||
# {
|
||||
# "name": "get_current_weather",
|
||||
|
@ -56,65 +59,77 @@ def test_completion_with_empty_model():
|
|||
# }
|
||||
# ])
|
||||
|
||||
# except Exception as e:
|
||||
# if "Function calling is not supported by nlp_cloud" in str(e):
|
||||
# except Exception as e:
|
||||
# if "Function calling is not supported by nlp_cloud" in str(e):
|
||||
# pass
|
||||
# else:
|
||||
# pytest.fail(f'An error occurred {e}')
|
||||
|
||||
# test_completion_catch_nlp_exception()
|
||||
# test_completion_catch_nlp_exception()
|
||||
|
||||
|
||||
def test_completion_invalid_param_cohere():
|
||||
try:
|
||||
try:
|
||||
response = completion(model="command-nightly", messages=messages, top_p=1)
|
||||
print(f"response: {response}")
|
||||
except Exception as e:
|
||||
if "Unsupported parameters passed: top_p" in str(e):
|
||||
except Exception as e:
|
||||
if "Unsupported parameters passed: top_p" in str(e):
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f'An error occurred {e}')
|
||||
else:
|
||||
pytest.fail(f"An error occurred {e}")
|
||||
|
||||
|
||||
# test_completion_invalid_param_cohere()
|
||||
|
||||
|
||||
def test_completion_function_call_cohere():
|
||||
try:
|
||||
response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
|
||||
pytest.fail(f'An error occurred {e}')
|
||||
except Exception as e:
|
||||
try:
|
||||
response = completion(
|
||||
model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]
|
||||
)
|
||||
pytest.fail(f"An error occurred {e}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# test_completion_function_call_cohere()
|
||||
|
||||
def test_completion_function_call_openai():
|
||||
try:
|
||||
|
||||
def test_completion_function_call_openai():
|
||||
try:
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
response = completion(model="gpt-3.5-turbo", messages=messages, functions=[
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
functions=[
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
])
|
||||
],
|
||||
)
|
||||
print(f"response: {response}")
|
||||
except:
|
||||
except:
|
||||
pass
|
||||
|
||||
# test_completion_function_call_openai()
|
||||
|
||||
# test_completion_function_call_openai()
|
||||
|
||||
|
||||
def test_completion_with_no_provider():
|
||||
# test on empty
|
||||
|
@ -125,6 +140,7 @@ def test_completion_with_no_provider():
|
|||
print(f"error occurred: {e}")
|
||||
pass
|
||||
|
||||
|
||||
# test_completion_with_no_provider()
|
||||
# # bad key
|
||||
# temp_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
@ -136,4 +152,4 @@ def test_completion_with_no_provider():
|
|||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
# pass
|
||||
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
|
||||
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
|
||||
|
|
|
@ -4,62 +4,78 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from openai import APITimeoutError as Timeout
|
||||
import litellm
|
||||
|
||||
litellm.num_retries = 0
|
||||
from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses
|
||||
from litellm import (
|
||||
batch_completion,
|
||||
batch_completion_models,
|
||||
completion,
|
||||
batch_completion_models_all_responses,
|
||||
)
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
|
||||
def test_batch_completions():
|
||||
messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)]
|
||||
model = "j2-mid"
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
result = batch_completion(
|
||||
model=model,
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.2,
|
||||
request_timeout=1
|
||||
request_timeout=1,
|
||||
)
|
||||
print(result)
|
||||
print(len(result))
|
||||
assert(len(result)==3)
|
||||
assert len(result) == 3
|
||||
except Timeout as e:
|
||||
print(f"IN TIMEOUT")
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred: {e}")
|
||||
|
||||
|
||||
test_batch_completions()
|
||||
|
||||
|
||||
def test_batch_completions_models():
|
||||
try:
|
||||
result = batch_completion_models(
|
||||
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
|
||||
messages=[{"role": "user", "content": "Hey, how's it going"}]
|
||||
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
|
||||
messages=[{"role": "user", "content": "Hey, how's it going"}],
|
||||
)
|
||||
print(result)
|
||||
except Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred: {e}")
|
||||
|
||||
|
||||
# test_batch_completions_models()
|
||||
|
||||
|
||||
def test_batch_completion_models_all_responses():
|
||||
try:
|
||||
responses = batch_completion_models_all_responses(
|
||||
models=["j2-light", "claude-instant-1.2"],
|
||||
models=["j2-light", "claude-instant-1.2"],
|
||||
messages=[{"role": "user", "content": "write a poem"}],
|
||||
max_tokens=10
|
||||
max_tokens=10,
|
||||
)
|
||||
print(responses)
|
||||
assert(len(responses) == 2)
|
||||
assert len(responses) == 2
|
||||
except Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred: {e}")
|
||||
# test_batch_completion_models_all_responses()
|
||||
|
||||
|
||||
# test_batch_completion_models_all_responses()
|
||||
|
|
|
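The batch tests above exercise three helpers imported from litellm. A short usage sketch with the same model names; the one-line descriptions of what each helper returns follow the tests' intent and should be read as assumptions, not verified behaviour:

from litellm import (
    batch_completion,
    batch_completion_models,
    batch_completion_models_all_responses,
)

# one model, several message lists -> one response per message list
results = batch_completion(
    model="j2-mid",
    messages=[[{"role": "user", "content": "write a short poem"}] for _ in range(3)],
    max_tokens=10,
    temperature=0.2,
)

# several models, one prompt -> the quickest successful response
fastest = batch_completion_models(
    models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
    messages=[{"role": "user", "content": "Hey, how's it going"}],
)

# several models, one prompt -> one response per model
all_responses = batch_completion_models_all_responses(
    models=["j2-light", "claude-instant-1.2"],
    messages=[{"role": "user", "content": "write a poem"}],
    max_tokens=10,
)
print(len(results), fastest, len(all_responses))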
@ -3,12 +3,12 @@
|
|||
|
||||
# import sys, os, json
|
||||
# import traceback
|
||||
# import pytest
|
||||
# import pytest
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
# import litellm
|
||||
# import litellm
|
||||
# litellm.set_verbose = True
|
||||
# from litellm import completion, BudgetManager
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
|||
|
||||
# ## Scenario 1: User budget enough to make call
|
||||
# def test_user_budget_enough():
|
||||
# try:
|
||||
# try:
|
||||
# user = "1234"
|
||||
# # create a budget for a user
|
||||
# budget_manager.create_budget(total_budget=10, user=user, duration="daily")
|
||||
|
@ -38,7 +38,7 @@
|
|||
|
||||
# ## Scenario 2: User budget not enough to make call
|
||||
# def test_user_budget_not_enough():
|
||||
# try:
|
||||
# try:
|
||||
# user = "12345"
|
||||
# # create a budget for a user
|
||||
# budget_manager.create_budget(total_budget=0, user=user, duration="daily")
|
||||
|
@ -60,7 +60,7 @@
|
|||
# except:
|
||||
# pytest.fail(f"An error occurred")
|
||||
|
||||
# ## Scenario 3: Saving budget to client
|
||||
# ## Scenario 3: Saving budget to client
|
||||
# def test_save_user_budget():
|
||||
# try:
|
||||
# response = budget_manager.save_data()
|
||||
|
@ -70,17 +70,17 @@
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"An error occurred: {str(e)}")
|
||||
|
||||
# test_save_user_budget()
|
||||
# ## Scenario 4: Getting list of users
|
||||
# test_save_user_budget()
|
||||
# ## Scenario 4: Getting list of users
|
||||
# def test_get_users():
|
||||
# try:
|
||||
# response = budget_manager.get_users()
|
||||
# print(response)
|
||||
# except:
|
||||
# pytest.fail(f"An error occurred")
|
||||
# pytest.fail(f"An error occurred")
|
||||
|
||||
|
||||
# ## Scenario 5: Reset budget at the end of duration
|
||||
# ## Scenario 5: Reset budget at the end of duration
|
||||
# def test_reset_on_duration():
|
||||
# try:
|
||||
# # First, set a short duration budget for a user
|
||||
|
@ -100,7 +100,7 @@
|
|||
|
||||
# # Now, we need to simulate the passing of time. Since we don't want our tests to actually take days, we're going
|
||||
# # to cheat a little -- we'll manually adjust the "created_at" time so it seems like a day has passed.
|
||||
# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time.
|
||||
# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time.
|
||||
# one_day_in_seconds = 24 * 60 * 60
|
||||
# budget_manager.user_dict[user]["last_updated_at"] -= one_day_in_seconds
|
||||
|
||||
|
@ -108,11 +108,11 @@
|
|||
# budget_manager.update_budget_all_users()
|
||||
|
||||
# # Make sure the budget was actually reset
|
||||
# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired"
|
||||
# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired"
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
# ## Scenario 6: passing in text:
|
||||
# ## Scenario 6: passing in text:
|
||||
# def test_input_text_on_completion():
|
||||
# try:
|
||||
# user = "12345"
|
||||
|
@ -127,4 +127,4 @@
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
# test_input_text_on_completion()
|
||||
# test_input_text_on_completion()
|
||||
|
|
|
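The commented-out scenarios above describe litellm's BudgetManager: create a per-user budget, record spend, persist, and reset once the duration elapses. A hedged sketch of that flow; the project_name constructor argument and the update_cost call follow the BudgetManager docs rather than anything shown in this diff, so treat them as assumptions:

from litellm import BudgetManager, completion

budget_manager = BudgetManager(project_name="test_project")  # assumed constructor argument
user = "1234"

# give the user a $10 daily budget (method used in the scenarios above)
budget_manager.create_budget(total_budget=10, user=user, duration="daily")

# make a call and record its cost against the user's budget
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
budget_manager.update_cost(completion_obj=response, user=user)  # assumed helper for recording spend
print(budget_manager.get_current_cost(user))

# persist budgets and roll them over once a duration has elapsed
budget_manager.save_data()
budget_manager.update_budget_all_users()
print(budget_manager.get_users())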
@ -14,6 +14,7 @@ import litellm
|
|||
from litellm import embedding, completion
|
||||
from litellm.caching import Cache
|
||||
import random
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||
|
@ -22,23 +23,30 @@ messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
|||
import random
|
||||
import string
|
||||
|
||||
|
||||
def generate_random_word(length=4):
|
||||
letters = string.ascii_lowercase
|
||||
return ''.join(random.choice(letters) for _ in range(length))
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
messages = [{"role": "user", "content": "who is ishaan 5222"}]
|
||||
def test_caching_v2(): # test in memory cache
|
||||
|
||||
|
||||
def test_caching_v2(): # test in memory cache
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
litellm.set_verbose = True
|
||||
litellm.cache = Cache()
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
litellm.cache = None # disable cache
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
if (
|
||||
response2["choices"][0]["message"]["content"]
|
||||
!= response1["choices"][0]["message"]["content"]
|
||||
):
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
|
@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache
|
|||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_caching_v2()
|
||||
|
||||
|
||||
|
||||
def test_caching_with_models_v2():
|
||||
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
|
||||
messages = [
|
||||
{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
|
||||
]
|
||||
litellm.cache = Cache()
|
||||
print("test2 for caching")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
|
@ -63,34 +73,51 @@ def test_caching_with_models_v2():
|
|||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
if (
|
||||
response3["choices"][0]["message"]["content"]
|
||||
== response2["choices"][0]["message"]["content"]
|
||||
):
|
||||
# if models are different, it should not return cached response
|
||||
print(f"response2: {response2}")
|
||||
print(f"response3: {response3}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']:
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
!= response2["choices"][0]["message"]["content"]
|
||||
):
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
|
||||
|
||||
# test_caching_with_models_v2()
|
||||
|
||||
embedding_large_text = """
|
||||
embedding_large_text = (
|
||||
"""
|
||||
small text
|
||||
""" * 5
|
||||
"""
|
||||
* 5
|
||||
)
|
||||
|
||||
|
||||
# # test_caching_with_models()
|
||||
def test_embedding_caching():
|
||||
import time
|
||||
|
||||
litellm.cache = Cache()
|
||||
text_to_embed = [embedding_large_text]
|
||||
start_time = time.time()
|
||||
embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
|
||||
embedding1 = embedding(
|
||||
model="text-embedding-ada-002", input=text_to_embed, caching=True
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Embedding 1 response time: {end_time - start_time} seconds")
|
||||
|
||||
time.sleep(1)
|
||||
start_time = time.time()
|
||||
embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
|
||||
embedding2 = embedding(
|
||||
model="text-embedding-ada-002", input=text_to_embed, caching=True
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"embedding2: {embedding2}")
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
@ -98,29 +125,30 @@ def test_embedding_caching():
|
|||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
|
||||
print(f"embedding1: {embedding1}")
|
||||
print(f"embedding2: {embedding2}")
|
||||
pytest.fail("Error occurred: Embedding caching failed")
|
||||
|
||||
|
||||
# test_embedding_caching()
|
||||
|
||||
|
||||
def test_embedding_caching_azure():
|
||||
print("Testing azure embedding caching")
|
||||
import time
|
||||
|
||||
litellm.cache = Cache()
|
||||
text_to_embed = [embedding_large_text]
|
||||
|
||||
api_key = os.environ['AZURE_API_KEY']
|
||||
api_base = os.environ['AZURE_API_BASE']
|
||||
api_version = os.environ['AZURE_API_VERSION']
|
||||
|
||||
os.environ['AZURE_API_VERSION'] = ""
|
||||
os.environ['AZURE_API_BASE'] = ""
|
||||
os.environ['AZURE_API_KEY'] = ""
|
||||
api_key = os.environ["AZURE_API_KEY"]
|
||||
api_base = os.environ["AZURE_API_BASE"]
|
||||
api_version = os.environ["AZURE_API_VERSION"]
|
||||
|
||||
os.environ["AZURE_API_VERSION"] = ""
|
||||
os.environ["AZURE_API_BASE"] = ""
|
||||
os.environ["AZURE_API_KEY"] = ""
|
||||
|
||||
start_time = time.time()
|
||||
print("AZURE CONFIGS")
|
||||
|
@ -133,7 +161,7 @@ def test_embedding_caching_azure():
|
|||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
caching=True
|
||||
caching=True,
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Embedding 1 response time: {end_time - start_time} seconds")
|
||||
|
@ -146,7 +174,7 @@ def test_embedding_caching_azure():
|
|||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
caching=True
|
||||
caching=True,
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
@ -154,15 +182,16 @@ def test_embedding_caching_azure():
|
|||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
|
||||
print(f"embedding1: {embedding1}")
|
||||
print(f"embedding2: {embedding2}")
|
||||
pytest.fail("Error occurred: Embedding caching failed")
|
||||
|
||||
os.environ['AZURE_API_VERSION'] = api_version
|
||||
os.environ['AZURE_API_BASE'] = api_base
|
||||
os.environ['AZURE_API_KEY'] = api_key
|
||||
os.environ["AZURE_API_VERSION"] = api_version
|
||||
os.environ["AZURE_API_BASE"] = api_base
|
||||
os.environ["AZURE_API_KEY"] = api_key
|
||||
|
||||
|
||||
# test_embedding_caching_azure()
|
||||
|
||||
|
@ -170,13 +199,28 @@ def test_embedding_caching_azure():
|
|||
def test_redis_cache_completion():
|
||||
litellm.set_verbose = False
|
||||
|
||||
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
random_number = random.randint(
|
||||
1, 100000
|
||||
) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [
|
||||
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
print("test2 for caching")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
|
||||
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||
)
|
||||
response3 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
|
||||
)
|
||||
response4 = completion(model="command-nightly", messages=messages, caching=True)
|
||||
|
||||
print("\nresponse 1", response1)
|
||||
|
@ -192,49 +236,88 @@ def test_redis_cache_completion():
|
|||
1 & 3 should be different, since input params are diff
|
||||
1 & 4 should be diff, since models are diff
|
||||
"""
|
||||
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
!= response2["choices"][0]["message"]["content"]
|
||||
): # 1 and 2 should be the same
|
||||
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']:
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
== response3["choices"][0]["message"]["content"]
|
||||
):
|
||||
# if input params like seed, max_tokens are diff it should NOT be a cache hit
|
||||
print(f"response1: {response1}")
|
||||
print(f"response3: {response3}")
|
||||
pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:")
|
||||
if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']:
|
||||
pytest.fail(
|
||||
f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:"
|
||||
)
|
||||
if (
|
||||
response1["choices"][0]["message"]["content"]
|
||||
== response4["choices"][0]["message"]["content"]
|
||||
):
|
||||
# if models are different, it should not return cached response
|
||||
print(f"response1: {response1}")
|
||||
print(f"response4: {response4}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
|
||||
|
||||
# test_redis_cache_completion()
|
||||
|
||||
|
||||
def test_redis_cache_completion_stream():
|
||||
try:
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.callbacks = []
|
||||
litellm.set_verbose = True
|
||||
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
random_number = random.randint(
|
||||
1, 100000
|
||||
) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"write a one sentence poem about: {random_number}",
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
print("test for caching, streaming + completion")
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=0.2,
|
||||
stream=True,
|
||||
)
|
||||
response_1_content = ""
|
||||
for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
time.sleep(0.5)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=0.2,
|
||||
stream=True,
|
||||
)
|
||||
response_2_content = ""
|
||||
for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.success_callback = []
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
|
@ -247,99 +330,171 @@ def test_redis_cache_completion_stream():
|
|||
|
||||
1 & 2 should be exactly the same
|
||||
"""
|
||||
|
||||
|
||||
# test_redis_cache_completion_stream()
|
||||
|
||||
|
||||
def test_redis_cache_acompletion_stream():
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
random_word = generate_random_word()
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"write a one sentence poem about: {random_word}",
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
print("test for caching, streaming + completion")
|
||||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
|
||||
asyncio.run(call1())
|
||||
time.sleep(0.5)
|
||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
response2 = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print(response_2_content)
|
||||
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
# test_redis_cache_acompletion_stream()
|
||||
|
||||
|
||||
def test_redis_cache_acompletion_stream_bedrock():
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
random_word = generate_random_word()
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"write a one sentence poem about: {random_word}",
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
print("test for caching, streaming + completion")
|
||||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(
|
||||
model="bedrock/anthropic.claude-v1",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
|
||||
asyncio.run(call1())
|
||||
time.sleep(0.5)
|
||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
response2 = await litellm.acompletion(
|
||||
model="bedrock/anthropic.claude-v1",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print(response_2_content)
|
||||
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
# test_redis_cache_acompletion_stream_bedrock()
|
||||
# redis cache with custom keys
|
||||
def custom_get_cache_key(*args, **kwargs):
|
||||
# return key to use for your cache:
|
||||
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", ""))
|
||||
# return key to use for your cache:
|
||||
key = (
|
||||
kwargs.get("model", "")
|
||||
+ str(kwargs.get("messages", ""))
|
||||
+ str(kwargs.get("temperature", ""))
|
||||
+ str(kwargs.get("logit_bias", ""))
|
||||
)
|
||||
return key
|
||||
|
||||
|
||||
def test_custom_redis_cache_with_key():
|
||||
messages = [{"role": "user", "content": "write a one line story"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
litellm.cache.get_cache_key = custom_get_cache_key
|
||||
|
||||
local_cache = {}
|
||||
|
@ -356,54 +511,72 @@ def test_custom_redis_cache_with_key():
|
|||
|
||||
# patch this redis cache get and set call
|
||||
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
|
||||
response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3)
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=1,
|
||||
caching=True,
|
||||
num_retries=3,
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=1,
|
||||
caching=True,
|
||||
num_retries=3,
|
||||
)
|
||||
response3 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=1,
|
||||
caching=False,
|
||||
num_retries=3,
|
||||
)
|
||||
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
print(f"response3: {response3}")
|
||||
|
||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
if (
|
||||
response3["choices"][0]["message"]["content"]
|
||||
== response2["choices"][0]["message"]["content"]
|
||||
):
|
||||
pytest.fail(f"Error occurred:")
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
|
||||
# test_custom_redis_cache_with_key()
|
||||
|
||||
|
||||
def test_cache_override():
|
||||
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
|
||||
# in this case it should not return cached responses
|
||||
# in this case it should not return cached responses
|
||||
litellm.cache = Cache()
|
||||
print("Testing cache override")
|
||||
litellm.set_verbose=True
|
||||
litellm.set_verbose = True
|
||||
|
||||
# test embedding
|
||||
response1 = embedding(
|
||||
model = "text-embedding-ada-002",
|
||||
input=[
|
||||
"hello who are you"
|
||||
],
|
||||
caching = False
|
||||
model="text-embedding-ada-002", input=["hello who are you"], caching=False
|
||||
)
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
response2 = embedding(
|
||||
model = "text-embedding-ada-002",
|
||||
input=[
|
||||
"hello who are you"
|
||||
],
|
||||
caching = False
|
||||
model="text-embedding-ada-002", input=["hello who are you"], caching=False
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
|
||||
# test_cache_override()
|
||||
assert (
|
||||
end_time - start_time > 0.1
|
||||
) # ensure 2nd response comes in over 0.1s. This should not be cached.
|
||||
|
||||
|
||||
# test_cache_override()
|
||||
|
||||
|
||||
def test_custom_redis_cache_params():
|
||||
|
@ -411,17 +584,17 @@ def test_custom_redis_cache_params():
|
|||
try:
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ['REDIS_HOST'],
|
||||
port=os.environ['REDIS_PORT'],
|
||||
password=os.environ['REDIS_PASSWORD'],
|
||||
db = 0,
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
db=0,
|
||||
ssl=True,
|
||||
ssl_certfile="./redis_user.crt",
|
||||
ssl_keyfile="./redis_user_private.key",
|
||||
ssl_ca_certs="./redis_ca.pem",
|
||||
)
|
||||
|
||||
print(litellm.cache.cache.redis_client)
|
||||
print(litellm.cache.cache.redis_client)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
@ -431,58 +604,126 @@ def test_custom_redis_cache_params():
|
|||
|
||||
def test_get_cache_key():
|
||||
from litellm.caching import Cache
|
||||
|
||||
try:
|
||||
print("Testing get_cache_key")
|
||||
cache_instance = Cache()
|
||||
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
|
||||
cache_key = cache_instance.get_cache_key(
|
||||
**{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "write a one sentence poem about: 7510"}
|
||||
],
|
||||
"max_tokens": 40,
|
||||
"temperature": 0.2,
|
||||
"stream": True,
|
||||
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
|
||||
"litellm_logging_obj": {},
|
||||
}
|
||||
)
|
||||
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
|
||||
cache_key_2 = cache_instance.get_cache_key(
|
||||
**{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "write a one sentence poem about: 7510"}
|
||||
],
|
||||
"max_tokens": 40,
|
||||
"temperature": 0.2,
|
||||
"stream": True,
|
||||
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
|
||||
"litellm_logging_obj": {},
|
||||
}
|
||||
)
|
||||
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
|
||||
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
|
||||
assert (
|
||||
cache_key
|
||||
== "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
|
||||
)
|
||||
assert (
|
||||
cache_key == cache_key_2
|
||||
), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
|
||||
|
||||
embedding_cache_key = cache_instance.get_cache_key(
|
||||
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
|
||||
'api_key': '', 'api_version': '2023-07-01-preview',
|
||||
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
|
||||
'caching': True,
|
||||
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
|
||||
**{
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_key": "",
|
||||
"api_version": "2023-07-01-preview",
|
||||
"timeout": None,
|
||||
"max_retries": 0,
|
||||
"input": ["hi who is ishaan"],
|
||||
"caching": True,
|
||||
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
|
||||
}
|
||||
)
|
||||
|
||||
print(embedding_cache_key)
|
||||
|
||||
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
|
||||
assert (
|
||||
embedding_cache_key
|
||||
== "model: azure/azure-embedding-modelinput: ['hi who is ishaan']"
|
||||
), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
|
||||
|
||||
# Proxy - embedding cache, test if embedding key, gets model_group and not model
|
||||
embedding_cache_key_2 = cache_instance.get_cache_key(
|
||||
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
|
||||
'api_key': '', 'api_version': '2023-07-01-preview',
|
||||
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
|
||||
'caching': True,
|
||||
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
|
||||
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
|
||||
'method': 'POST',
|
||||
'headers':
|
||||
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
|
||||
'content-length': '80'},
|
||||
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
|
||||
'user': None,
|
||||
'metadata': {'user_api_key': None,
|
||||
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
|
||||
'model_group': 'EMBEDDING_MODEL_GROUP',
|
||||
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
|
||||
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
|
||||
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
|
||||
**{
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_key": "",
|
||||
"api_version": "2023-07-01-preview",
|
||||
"timeout": None,
|
||||
"max_retries": 0,
|
||||
"input": ["hi who is ishaan"],
|
||||
"caching": True,
|
||||
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
|
||||
"proxy_server_request": {
|
||||
"url": "http://0.0.0.0:8000/embeddings",
|
||||
"method": "POST",
|
||||
"headers": {
|
||||
"host": "0.0.0.0:8000",
|
||||
"user-agent": "curl/7.88.1",
|
||||
"accept": "*/*",
|
||||
"content-type": "application/json",
|
||||
"content-length": "80",
|
||||
},
|
||||
"body": {
|
||||
"model": "azure-embedding-model",
|
||||
"input": ["hi who is ishaan"],
|
||||
},
|
||||
},
|
||||
"user": None,
|
||||
"metadata": {
|
||||
"user_api_key": None,
|
||||
"headers": {
|
||||
"host": "0.0.0.0:8000",
|
||||
"user-agent": "curl/7.88.1",
|
||||
"accept": "*/*",
|
||||
"content-type": "application/json",
|
||||
"content-length": "80",
|
||||
},
|
||||
"model_group": "EMBEDDING_MODEL_GROUP",
|
||||
"deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview",
|
||||
},
|
||||
"model_info": {
|
||||
"mode": "embedding",
|
||||
"base_model": "text-embedding-ada-002",
|
||||
"id": "20b2b515-f151-4dd5-a74f-2231e2f54e29",
|
||||
},
|
||||
"litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e",
|
||||
"litellm_logging_obj": "<litellm.utils.Logging object at 0x12f1bddb0>",
|
||||
}
|
||||
)
|
||||
|
||||
print(embedding_cache_key_2)
|
||||
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
|
||||
assert (
|
||||
embedding_cache_key_2
|
||||
== "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
|
||||
)
|
||||
print("passed!")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred:", e)
|
||||
|
||||
|
||||
test_get_cache_key()
|
||||
|
||||
# test_custom_redis_cache_params()
|
||||
|
@ -581,4 +822,4 @@ test_get_cache_key()
|
|||
# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
|
||||
# time.sleep(2)
|
||||
# assert cache.get_cache(cache_key="test_key") is None
|
||||
# # test_in_memory_cache_with_ttl()
|
||||
# # test_in_memory_cache_with_ttl()
|
||||
|
|
|
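The caching tests above all follow the same pattern: point litellm.cache at an in-memory or Redis cache, pass caching=True on each call, and expect identical requests to return the stored response. A condensed sketch, assuming REDIS_HOST, REDIS_PORT and REDIS_PASSWORD are set in the environment:

import os

import litellm
from litellm import completion
from litellm.caching import Cache

messages = [{"role": "user", "content": "write a one sentence poem about caching"}]

# in-memory cache: the second identical call is served from the cache
litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# redis cache: same call pattern, shared across processes
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)


# a custom key function controls how requests are deduplicated
def custom_get_cache_key(*args, **kwargs):
    return kwargs.get("model", "") + str(kwargs.get("messages", ""))


litellm.cache.get_cache_key = custom_get_cache_key

litellm.cache = None  # disable caching again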
@ -1,5 +1,5 @@
|
|||
#### What this tests ####
|
||||
# This tests using caching w/ litellm which requires SSL=True
|
||||
# This tests using caching w/ litellm which requires SSL=True
|
||||
|
||||
import sys, os
|
||||
import time
|
||||
|
@ -18,15 +18,26 @@ from litellm import embedding, completion, Router
|
|||
from litellm.caching import Cache
|
||||
|
||||
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
|
||||
def test_caching_v2(): # test in memory cache
|
||||
|
||||
|
||||
def test_caching_v2(): # test in memory cache
|
||||
try:
|
||||
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host="os.environ/REDIS_HOST_2",
|
||||
port="os.environ/REDIS_PORT_2",
|
||||
password="os.environ/REDIS_PASSWORD_2",
|
||||
ssl="os.environ/REDIS_SSL_2",
|
||||
)
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
litellm.cache = None # disable cache
|
||||
if (
|
||||
response2["choices"][0]["message"]["content"]
|
||||
!= response1["choices"][0]["message"]["content"]
|
||||
):
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
raise Exception()
|
||||
|
@ -34,41 +45,57 @@ def test_caching_v2(): # test in memory cache
|
|||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_caching_v2()
|
||||
|
||||
|
||||
def test_caching_router():
|
||||
"""
|
||||
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
|
||||
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
|
||||
router = Router(model_list=model_list,
|
||||
routing_strategy="simple-shuffle",
|
||||
set_verbose=False,
|
||||
num_retries=1) # type: ignore
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host="os.environ/REDIS_HOST_2",
|
||||
port="os.environ/REDIS_PORT_2",
|
||||
password="os.environ/REDIS_PASSWORD_2",
|
||||
ssl="os.environ/REDIS_SSL_2",
|
||||
)
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle",
|
||||
set_verbose=False,
|
||||
num_retries=1,
|
||||
) # type: ignore
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
if (
|
||||
response2["choices"][0]["message"]["content"]
|
||||
!= response1["choices"][0]["message"]["content"]
|
||||
):
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
|
||||
litellm.cache = None # disable cache
|
||||
assert (
|
||||
response2["choices"][0]["message"]["content"]
|
||||
== response1["choices"][0]["message"]["content"]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_caching_router()
|
||||
|
||||
# test_caching_router()
|
||||
|
|
|
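test_caching_router above shows that a cache configured on litellm is honoured even when the individual call does not pass caching=True, with the deployment defined through a Router model_list. A trimmed sketch of that setup; the "os.environ/<VAR>" strings are passed exactly as the test does, on the assumption that litellm resolves them from the environment:

import os

import litellm
from litellm import Router, completion
from litellm.caching import Cache

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # alias callers use
        "litellm_params": {  # forwarded to the completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    }
]

litellm.cache = Cache(
    type="redis",
    host="os.environ/REDIS_HOST_2",
    port="os.environ/REDIS_PORT_2",
    password="os.environ/REDIS_PASSWORD_2",
    ssl="os.environ/REDIS_SSL_2",
)

router = Router(
    model_list=model_list,
    routing_strategy="simple-shuffle",
    set_verbose=False,
    num_retries=1,
)

# as in the test, a plain completion() call still benefits from litellm.cache
response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])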
@ -8,7 +8,7 @@
|
|||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
# import litellm
|
||||
# import asyncio
|
||||
# import asyncio
|
||||
|
||||
# litellm.set_verbose = True
|
||||
# from litellm import Router
|
||||
|
@ -18,9 +18,9 @@
|
|||
# # This enables response_model keyword
|
||||
# # # from client.chat.completions.create
|
||||
# # client = instructor.patch(Router(model_list=[{
|
||||
# # "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# # "litellm_params": { # params for litellm completion/embedding call
|
||||
# # "model": "azure/chatgpt-v-2",
|
||||
# # "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# # "litellm_params": { # params for litellm completion/embedding call
|
||||
# # "model": "azure/chatgpt-v-2",
|
||||
# # "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# # "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# # "api_base": os.getenv("AZURE_API_BASE")
|
||||
|
@ -49,9 +49,9 @@
|
|||
# from openai import AsyncOpenAI
|
||||
|
||||
# aclient = instructor.apatch(Router(model_list=[{
|
||||
# "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# "litellm_params": { # params for litellm completion/embedding call
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "model_name": "gpt-3.5-turbo", # openai model name
|
||||
# "litellm_params": { # params for litellm completion/embedding call
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE")
|
||||
|
@ -71,4 +71,4 @@
|
|||
# )
|
||||
# print(f"model: {model}")
|
||||
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
|
File diff suppressed because it is too large
|
@ -28,6 +28,7 @@ def logger_fn(user_model_dict):
|
|||
# print(f"user_model_dict: {user_model_dict}")
|
||||
pass
|
||||
|
||||
|
||||
# normal call
|
||||
def test_completion_custom_provider_model_name():
|
||||
try:
|
||||
|
@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# completion with num retries + impact on exception mapping
|
||||
def test_completion_with_num_retries():
|
||||
try:
|
||||
response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2)
|
||||
def test_completion_with_num_retries():
|
||||
try:
|
||||
response = completion(
|
||||
model="j2-ultra",
|
||||
messages=[{"messages": "vibe", "bad": "message"}],
|
||||
num_retries=2,
|
||||
)
|
||||
pytest.fail(f"Unmapped exception occurred")
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
# test_completion_with_num_retries()
|
||||
def test_completion_with_0_num_retries():
|
||||
try:
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose = False
|
||||
print("making request")
|
||||
|
||||
# Use the completion function
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"gm": "vibe", "role": "user"}],
|
||||
max_retries=4
|
||||
max_retries=4,
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
@ -69,5 +76,6 @@ def test_completion_with_0_num_retries():
|
|||
print("exception", e)
|
||||
pass
|
||||
|
||||
|
||||
# Call the test function
|
||||
test_completion_with_0_num_retries()
|
||||
test_completion_with_0_num_retries()
|
||||
|
|
|
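The retry tests above pass two different knobs through completion(): num_retries (handled by litellm itself) and max_retries (forwarded to the underlying client). That split is how the two parameters are usually described and should be taken as an assumption here; a minimal sketch with an OPENAI_API_KEY assumed in the environment:

from litellm import completion

messages = [{"role": "user", "content": "Hello, how are you?"}]

# litellm-level retries: retry the call up to 2 extra times on mapped exceptions
response = completion(model="gpt-3.5-turbo", messages=messages, num_retries=2)

# client-level retries: forwarded to the provider SDK
response = completion(model="gpt-3.5-turbo", messages=messages, max_retries=4)
print(response)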
@ -15,77 +15,104 @@ from litellm import completion_with_config
|
|||
config = {
|
||||
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
|
||||
"model": {
|
||||
"claude-instant-1": {
|
||||
"needs_moderation": True
|
||||
},
|
||||
"claude-instant-1": {"needs_moderation": True},
|
||||
"gpt-3.5-turbo": {
|
||||
"error_handling": {
|
||||
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
|
||||
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_config_context_window_exceeded():
|
||||
try:
|
||||
sample_text = "how does a court case get to the Supreme Court?" * 1000
|
||||
messages = [{"content": sample_text, "role": "user"}]
|
||||
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
|
||||
response = completion_with_config(
|
||||
model="gpt-3.5-turbo", messages=messages, config=config
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# test_config_context_window_exceeded()
|
||||
|
||||
# test_config_context_window_exceeded()
|
||||
|
||||
|
||||
def test_config_context_moderation():
|
||||
try:
|
||||
messages=[{"role": "user", "content": "I want to kill them."}]
|
||||
response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
|
||||
messages = [{"role": "user", "content": "I want to kill them."}]
|
||||
response = completion_with_config(
|
||||
model="claude-instant-1", messages=messages, config=config
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# test_config_context_moderation()
|
||||
|
||||
# test_config_context_moderation()
|
||||
|
||||
|
||||
def test_config_context_default_fallback():
|
||||
try:
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||
response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key")
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
response = completion_with_config(
|
||||
model="claude-instant-1",
|
||||
messages=messages,
|
||||
config=config,
|
||||
api_key="bad-key",
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# test_config_context_default_fallback()
|
||||
|
||||
# test_config_context_default_fallback()
|
||||
|
||||
|
||||
config = {
|
||||
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
|
||||
"available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613",
|
||||
"j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"],
|
||||
"adapt_to_prompt_size": True, # type: ignore
|
||||
"available_models": [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"j2-ultra",
|
||||
"command-nightly",
|
||||
"togethercomputer/llama-2-70b-chat",
|
||||
"chat-bison",
|
||||
"chat-bison@001",
|
||||
"claude-2",
|
||||
],
|
||||
"adapt_to_prompt_size": True, # type: ignore
|
||||
"model": {
|
||||
"claude-instant-1": {
|
||||
"needs_moderation": True
|
||||
},
|
||||
"claude-instant-1": {"needs_moderation": True},
|
||||
"gpt-3.5-turbo": {
|
||||
"error_handling": {
|
||||
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
|
||||
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_config_context_adapt_to_prompt():
|
||||
try:
|
||||
sample_text = "how does a court case get to the Supreme Court?" * 1000
|
||||
messages = [{"content": sample_text, "role": "user"}]
|
||||
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
|
||||
response = completion_with_config(
|
||||
model="gpt-3.5-turbo", messages=messages, config=config
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
test_config_context_adapt_to_prompt()
|
||||
|
||||
test_config_context_adapt_to_prompt()
|
||||
|
|
|
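completion_with_config, exercised above, takes its routing rules from a plain config dict: default fallbacks, per-model moderation, context-window fallbacks and prompt-size adaptation. A condensed sketch of such a config; the inline comments describe the behaviour the test names imply and are not verified semantics:

from litellm import completion_with_config

config = {
    # models tried when the requested one fails outright
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
    # with adapt_to_prompt_size, a model that fits the prompt is chosen from this list
    "available_models": ["gpt-3.5-turbo", "gpt-4", "j2-ultra", "claude-2"],
    "adapt_to_prompt_size": True,
    "model": {
        # run a moderation check before calling this model
        "claude-instant-1": {"needs_moderation": True},
        # swap to a larger context window when the prompt is too long
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}

response = completion_with_config(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    config=config,
)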
@ -1,14 +1,16 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
print(f"api_key: {api_key}")
|
||||
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
||||
|
|
|
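The custom auth module above is the entire contract the proxy needs: an async user_api_key_auth(request, api_key) that returns a UserAPIKeyAuth on success and raises on anything else. A minimal variant, assuming the accepted keys live in a hypothetical ALLOWED_PROXY_KEYS environment variable:

import os

from fastapi import Request

from litellm.proxy._types import UserAPIKeyAuth


async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    # accept any key listed in ALLOWED_PROXY_KEYS (comma-separated); reject the rest
    allowed = [k for k in os.getenv("ALLOWED_PROXY_KEYS", "").split(",") if k]
    if api_key in allowed:
        return UserAPIKeyAuth(api_key=api_key)
    raise Exception("invalid api key")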
@ -2,30 +2,35 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
import inspect
|
||||
import litellm
|
||||
|
||||
|
||||
class testCustomCallbackProxy(CustomLogger):
|
||||
def __init__(self):
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
self.async_success: bool = False # type: ignore
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
self.async_success: bool = False # type: ignore
|
||||
self.async_success_embedding: bool = False # type: ignore
|
||||
self.async_failure: bool = False # type: ignore
|
||||
self.async_failure: bool = False # type: ignore
|
||||
self.async_failure_embedding: bool = False # type: ignore
|
||||
|
||||
self.async_completion_kwargs = None # type: ignore
|
||||
self.async_embedding_kwargs = None # type: ignore
|
||||
self.async_embedding_response = None # type: ignore
|
||||
self.async_completion_kwargs = None # type: ignore
|
||||
self.async_embedding_kwargs = None # type: ignore
|
||||
self.async_embedding_response = None # type: ignore
|
||||
|
||||
self.async_completion_kwargs_fail = None # type: ignore
|
||||
self.async_embedding_kwargs_fail = None # type: ignore
|
||||
self.async_completion_kwargs_fail = None # type: ignore
|
||||
self.async_embedding_kwargs_fail = None # type: ignore
|
||||
|
||||
self.streaming_response_obj = None # type: ignore
|
||||
self.streaming_response_obj = None # type: ignore
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
print(f"{blue_color_code}Initialized LiteLLM custom logger")
|
||||
try:
|
||||
print(f"Logger Initialized with following methods:")
|
||||
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
|
||||
|
||||
methods = [
|
||||
method
|
||||
for method in dir(self)
|
||||
if inspect.ismethod(getattr(self, method))
|
||||
]
|
||||
|
||||
# Pretty print the methods
|
||||
for method in methods:
|
||||
print(f" - {method}")
|
||||
|
@ -33,29 +38,32 @@ class testCustomCallbackProxy(CustomLogger):
|
|||
except:
|
||||
pass
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"Post-API Call")
|
||||
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Stream")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
self.success = True
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
self.failure = True
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async success")
|
||||
self.async_success = True
|
||||
print("Value of async success: ", self.async_success)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
|
||||
if (
|
||||
kwargs.get("model") == "azure-embedding-model"
|
||||
or kwargs.get("model") == "ada"
|
||||
):
|
||||
print("Got an embedding model", kwargs.get("model"))
|
||||
print("Setting embedding success to True")
|
||||
self.async_success_embedding = True
|
||||
|
@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger):
|
|||
if kwargs.get("stream") == True:
|
||||
self.streaming_response_obj = response_obj
|
||||
|
||||
|
||||
self.async_completion_kwargs = kwargs
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
|
@ -74,17 +81,18 @@ class testCustomCallbackProxy(CustomLogger):
|
|||
|
||||
# Access litellm_params passed to litellm.completion(), example access `metadata`
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
|
||||
metadata = litellm_params.get(
|
||||
"metadata", {}
|
||||
) # headers passed to LiteLLM proxy, can be found here
|
||||
|
||||
# Calculate cost using litellm.completion_cost()
|
||||
cost = litellm.completion_cost(completion_response=response_obj)
|
||||
response = response_obj
|
||||
# tokens used in response
|
||||
# tokens used in response
|
||||
usage = response_obj["usage"]
|
||||
|
||||
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
|
||||
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
|
@ -98,8 +106,7 @@ class testCustomCallbackProxy(CustomLogger):
|
|||
)
|
||||
return
|
||||
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
self.async_failure = True
|
||||
print("Value of async failure: ", self.async_failure)
|
||||
|
@ -107,7 +114,8 @@ class testCustomCallbackProxy(CustomLogger):
|
|||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_failure_embedding = True
|
||||
self.async_embedding_kwargs_fail = kwargs
|
||||
|
||||
|
||||
self.async_completion_kwargs_fail = kwargs
|
||||
|
||||
my_custom_logger = testCustomCallbackProxy()
|
||||
|
||||
my_custom_logger = testCustomCallbackProxy()
|
||||
|
|
File diff suppressed because it is too large
|
@ -3,7 +3,8 @@
|
|||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
from typing import Optional, Literal, List
|
||||
from litellm import Router, Cache
|
||||
import litellm
|
||||
|
@ -14,206 +15,274 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
## 2: Post-API-Call
|
||||
## 3: On LiteLLM Call success
|
||||
## 4: On LiteLLM Call failure
|
||||
## fallbacks
|
||||
## retries
|
||||
## fallbacks
|
||||
## retries
|
||||
|
||||
# Test cases
|
||||
## 1. Simple Azure OpenAI acompletion + streaming call
|
||||
## 2. Simple Azure OpenAI aembedding call
|
||||
# Test cases
|
||||
## 1. Simple Azure OpenAI acompletion + streaming call
|
||||
## 2. Simple Azure OpenAI aembedding call
|
||||
## 3. Azure OpenAI acompletion + streaming call with retries
|
||||
## 4. Azure OpenAI aembedding call with retries
|
||||
## 5. Azure OpenAI acompletion + streaming call with fallbacks
|
||||
## 6. Azure OpenAI aembedding call with fallbacks
|
||||
|
||||
# Test interfaces
|
||||
## 1. router.completion() + router.embeddings()
|
||||
## 2. proxy.completions + proxy.embeddings
|
||||
## 1. router.completion() + router.embeddings()
|
||||
## 2. proxy.completions + proxy.embeddings
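The handler below implements this plan; a condensed sketch of the wiring these cases exercise (a CustomLogger registered on litellm.callbacks, then routed calls; valid Azure credentials in the environment are assumed):

import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger


class MinimalHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("success for model:", kwargs.get("model"))


litellm.callbacks = [MinimalHandler()]
# router = Router(model_list=[...])  # model_list entries as in the tests below
# response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])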
|
||||
|
||||
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
|
||||
class CompletionCustomHandler(
|
||||
CustomLogger
|
||||
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
"""
|
||||
The set of expected inputs to a custom handler for a
|
||||
The set of expected inputs to a custom handler for a
|
||||
"""
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
self.states: Optional[
|
||||
List[
|
||||
Literal[
|
||||
"sync_pre_api_call",
|
||||
"async_pre_api_call",
|
||||
"post_api_call",
|
||||
"sync_stream",
|
||||
"async_stream",
|
||||
"sync_success",
|
||||
"async_success",
|
||||
"sync_failure",
|
||||
"async_failure",
|
||||
]
|
||||
]
|
||||
] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
print(f'received kwargs in pre-input: {kwargs}')
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
print(f"received kwargs in pre-input: {kwargs}")
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
|
||||
)
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
|
||||
)
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert end_time == None
|
||||
## RESPONSE OBJECT
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert isinstance(kwargs["input"], (list, dict, str))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
or inspect.iscoroutine(kwargs["original_response"])
|
||||
or inspect.isasyncgen(kwargs["original_response"])
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
|
||||
)
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
|
||||
)
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list) and isinstance(
|
||||
kwargs["messages"][0], dict
|
||||
)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(kwargs["input"], list)
|
||||
and isinstance(kwargs["input"][0], dict)
|
||||
) or isinstance(kwargs["input"], (dict, str))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
or inspect.isasyncgen(kwargs["original_response"])
|
||||
or inspect.iscoroutine(kwargs["original_response"])
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list) and isinstance(
|
||||
kwargs["messages"][0], dict
|
||||
)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(kwargs["input"], list)
|
||||
and isinstance(kwargs["input"][0], dict)
|
||||
) or isinstance(kwargs["input"], (dict, str))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list) and isinstance(
|
||||
kwargs["messages"][0], dict
|
||||
)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(kwargs["input"], list)
|
||||
and isinstance(kwargs["input"][0], dict)
|
||||
) or isinstance(kwargs["input"], (dict, str))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
or kwargs["original_response"] == None
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
try:
|
||||
"""
|
||||
No-op.
|
||||
Not implemented yet.
|
||||
No-op.
|
||||
Not implemented yet.
|
||||
"""
|
||||
pass
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(
|
||||
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
|
||||
)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert isinstance(kwargs["input"], (list, dict, str))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
or inspect.isasyncgen(kwargs["original_response"])
|
||||
or inspect.iscoroutine(kwargs["original_response"])
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
|
@ -221,10 +290,14 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
|
||||
)
|
||||
assert isinstance(
|
||||
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
|
||||
)
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
@ -232,257 +305,281 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
try:
|
||||
print(f"received original response: {kwargs['original_response']}")
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
assert isinstance(kwargs["model"], str)
|
||||
assert isinstance(kwargs["messages"], list)
|
||||
assert isinstance(kwargs["optional_params"], dict)
|
||||
assert isinstance(kwargs["litellm_params"], dict)
|
||||
assert isinstance(kwargs["start_time"], (datetime, type(None)))
|
||||
assert isinstance(kwargs["stream"], bool)
|
||||
assert isinstance(kwargs["user"], (str, type(None)))
|
||||
assert isinstance(kwargs["input"], (list, str, dict))
|
||||
assert isinstance(kwargs["api_key"], (str, type(None)))
|
||||
assert (
|
||||
isinstance(
|
||||
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
|
||||
)
|
||||
or inspect.isasyncgen(kwargs["original_response"])
|
||||
or inspect.iscoroutine(kwargs["original_response"])
|
||||
or kwargs["original_response"] == None
|
||||
)
|
||||
assert isinstance(kwargs["additional_args"], (dict, type(None)))
|
||||
assert isinstance(kwargs["log_event_type"], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
# Simple Azure OpenAI call
|
||||
|
||||
# Simple Azure OpenAI call
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure():
|
||||
try:
|
||||
try:
|
||||
customHandler_completion_azure_router = CompletionCustomHandler()
|
||||
customHandler_streaming_azure_router = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_completion_azure_router]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler_completion_azure_router.errors) == 0
|
||||
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
|
||||
# streaming
|
||||
assert (
|
||||
len(customHandler_completion_azure_router.states) == 3
|
||||
) # pre, post, success
|
||||
# streaming
|
||||
litellm.callbacks = [customHandler_streaming_azure_router]
|
||||
router2 = Router(model_list=model_list) # type: ignore
|
||||
response = await router2.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
router2 = Router(model_list=model_list) # type: ignore
|
||||
response = await router2.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
print(f"async azure router chunk: {chunk}")
|
||||
continue
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
|
||||
assert len(customHandler_streaming_azure_router.errors) == 0
|
||||
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
|
||||
# failure
|
||||
assert (
|
||||
len(customHandler_streaming_azure_router.states) >= 4
|
||||
) # pre, post, stream (multiple times), success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
)
|
||||
print(f"response in router3 acompletion: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
# asyncio.run(test_async_chat_azure())
|
||||
## EMBEDDING
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.aembedding(
|
||||
model="azure-embedding-model", input=["hello from litellm!"]
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler.errors) == 0
|
||||
assert len(customHandler.states) == 3 # pre, post, success
|
||||
# failure
|
||||
assert len(customHandler.states) == 3 # pre, post, success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.aembedding(
|
||||
model="azure-embedding-model", input=["hello from litellm!"]
|
||||
)
|
||||
print(f"response in router3 aembedding: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
# Azure OpenAI call w/ Fallbacks
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure_with_fallbacks():
|
||||
try:
|
||||
async def test_async_chat_azure_with_fallbacks():
|
||||
try:
|
||||
customHandler_fallbacks = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_fallbacks]
|
||||
# with fallbacks
|
||||
# with fallbacks
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
|
||||
assert len(customHandler_fallbacks.errors) == 0
|
||||
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
|
||||
assert (
|
||||
len(customHandler_fallbacks.states) == 6
|
||||
) # pre, post, failure, pre, post, success
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
# asyncio.run(test_async_chat_azure_with_fallbacks())
|
||||
|
||||
# CACHING
|
||||
|
||||
# CACHING
|
||||
## Test Azure - completion, embedding
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_completion_azure_caching():
|
||||
customHandler_caching = CompletionCustomHandler()
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = time.time()
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response1 = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response1 = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
|
||||
caching=True,
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
|
||||
response2 = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
|
||||
caching=True,
|
||||
)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(
|
||||
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
|
||||
)
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import sys
|
||||
import os
|
||||
import io, asyncio
|
||||
|
||||
# import logging
|
||||
# logging.basicConfig(level=logging.DEBUG)
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
from litellm import completion
|
||||
import litellm
|
||||
|
||||
litellm.num_retries = 3
|
||||
|
||||
import time, random
|
||||
|
@ -29,11 +31,14 @@ def pre_request():
|
|||
|
||||
|
||||
import re
|
||||
def verify_log_file(log_file_path):
|
||||
|
||||
with open(log_file_path, 'r') as log_file:
|
||||
|
||||
def verify_log_file(log_file_path):
|
||||
with open(log_file_path, "r") as log_file:
|
||||
log_content = log_file.read()
|
||||
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)
|
||||
print(
|
||||
f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content
|
||||
)
|
||||
|
||||
# Define the pattern to search for in the log file
|
||||
pattern = r"Response from DynamoDB:{.*?}"
|
||||
|
@ -50,17 +55,21 @@ def verify_log_file(log_file_path):
|
|||
print(f"Total occurrences of specified response: {len(matches)}")
|
||||
|
||||
# Count the occurrences of successful responses (status code 200 or 201)
|
||||
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)
|
||||
success_count = sum(
|
||||
1
|
||||
for match in matches
|
||||
if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
|
||||
)
|
||||
|
||||
# Print the count of successful responses
|
||||
print(f"Count of successful responses from DynamoDB: {success_count}")
|
||||
assert success_count == 3 # Expect 3 success logs from dynamoDB
|
||||
assert success_count == 3 # Expect 3 success logs from dynamoDB
|
||||
|
||||
|
||||
def test_dynamo_logging():
|
||||
def test_dynamo_logging():
|
||||
# all dynamodb requests need to be in one test function
|
||||
# since we are modifying stdout, and pytest runs tests in parallel
|
||||
try:
|
||||
try:
|
||||
# pre
|
||||
# redirect stdout to log_file
|
||||
|
||||
|
@ -69,44 +78,44 @@ def test_dynamo_logging():
|
|||
litellm.set_verbose = True
|
||||
original_stdout, log_file, file_name = pre_request()
|
||||
|
||||
|
||||
print("Testing async dynamoDB logging")
|
||||
|
||||
async def _test():
|
||||
return await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
user = "ishaan-2"
|
||||
user="ishaan-2",
|
||||
)
|
||||
|
||||
response = asyncio.run(_test())
|
||||
print(f"response: {response}")
|
||||
|
||||
|
||||
# streaming + async
|
||||
# streaming + async
|
||||
async def _test2():
|
||||
response = await litellm.acompletion(
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
max_tokens=10,
|
||||
temperature=0.7,
|
||||
user = "ishaan-2",
|
||||
stream=True
|
||||
user="ishaan-2",
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
pass
|
||||
|
||||
asyncio.run(_test2())
|
||||
|
||||
# aembedding()
|
||||
async def _test3():
|
||||
return await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hi"],
|
||||
user = "ishaan-2"
|
||||
model="text-embedding-ada-002", input=["hi"], user="ishaan-2"
|
||||
)
|
||||
|
||||
response = asyncio.run(_test3())
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
finally:
|
||||
# post, close log file and verify
|
||||
|
@ -117,4 +126,5 @@ def test_dynamo_logging():
|
|||
verify_log_file(file_name)
|
||||
print("Passed! Testing async dynamoDB logging")
|
||||
|
||||
|
||||
# test_dynamo_logging_async()
|
||||
|
|
|
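For reference, the log-verification idea above reduced to a standalone helper (the log file name is hypothetical; the pattern and status-code checks mirror verify_log_file):

import re


def count_dynamo_successes(log_file_path):
    # Count DynamoDB responses in the captured log that returned HTTP 200/201.
    with open(log_file_path, "r") as log_file:
        log_content = log_file.read()
    matches = re.findall(r"Response from DynamoDB:{.*?}", log_content)
    return sum(
        1
        for match in matches
        if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
    )


# print(count_dynamo_successes("dynamo_test_output.log"))  # hypothetical file name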
@ -14,39 +14,49 @@ from litellm import embedding, completion
|
|||
|
||||
litellm.set_verbose = False
|
||||
|
||||
|
||||
def test_openai_embedding():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
litellm.set_verbose = True
|
||||
response = embedding(
|
||||
model="text-embedding-ada-002",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
metadata = {"anything": "good day"}
|
||||
model="text-embedding-ada-002",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
metadata={"anything": "good day"},
|
||||
)
|
||||
litellm_response = dict(response)
|
||||
litellm_response_keys = set(litellm_response.keys())
|
||||
litellm_response_keys.discard('_response_ms')
|
||||
litellm_response_keys.discard("_response_ms")
|
||||
|
||||
print(litellm_response_keys)
|
||||
print("LiteLLM Response\n")
|
||||
# print(litellm_response)
|
||||
|
||||
# same request with OpenAI 1.0+
|
||||
|
||||
# same request with OpenAI 1.0+
|
||||
import openai
|
||||
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'])
|
||||
|
||||
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"]
|
||||
model="text-embedding-ada-002",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
)
|
||||
|
||||
response = dict(response)
|
||||
openai_response_keys = set(response.keys())
|
||||
print(openai_response_keys)
|
||||
assert litellm_response_keys == openai_response_keys # ENSURE the Keys in litellm response is exactly what the openai package returns
|
||||
assert len(litellm_response["data"]) == 2 # expect two embedding responses from litellm_response since input had two
|
||||
assert (
|
||||
litellm_response_keys == openai_response_keys
|
||||
) # ENSURE the Keys in litellm response is exactly what the openai package returns
|
||||
assert (
|
||||
len(litellm_response["data"]) == 2
|
||||
) # expect two embedding responses from litellm_response since input had two
|
||||
print(openai_response_keys)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_openai_embedding()
|
||||
|
||||
|
||||
def test_openai_azure_embedding_simple():
|
||||
try:
|
||||
response = embedding(
|
||||
|
@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple():
|
|||
)
|
||||
print(response)
|
||||
response_keys = set(dict(response).keys())
|
||||
response_keys.discard('_response_ms')
|
||||
assert set(["usage", "model", "object", "data"]) == set(response_keys) #assert litellm response has expected keys from OpenAI embedding response
|
||||
response_keys.discard("_response_ms")
|
||||
assert set(["usage", "model", "object", "data"]) == set(
|
||||
response_keys
|
||||
) # assert litellm response has expected keys from OpenAI embedding response
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_openai_azure_embedding_simple()
|
||||
|
||||
|
||||
|
@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts():
|
|||
response = embedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"],
|
||||
timeout=0.00001
|
||||
timeout=0.00001,
|
||||
)
|
||||
print(response)
|
||||
except openai.APITimeoutError:
|
||||
print("Good job got timeout error!")
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
|
||||
pytest.fail(
|
||||
f"Expected timeout error, did not get the correct error. Instead got {e}"
|
||||
)
|
||||
|
||||
|
||||
# test_openai_azure_embedding_timeouts()
|
||||
|
||||
|
||||
def test_openai_embedding_timeouts():
|
||||
try:
|
||||
response = embedding(
|
||||
model="text-embedding-ada-002",
|
||||
input=["good morning from litellm"],
|
||||
timeout=0.00001
|
||||
timeout=0.00001,
|
||||
)
|
||||
print(response)
|
||||
except openai.APITimeoutError:
|
||||
print("Good job got OpenAI timeout error!")
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
|
||||
pytest.fail(
|
||||
f"Expected timeout error, did not get the correct error. Instead got {e}"
|
||||
)
|
||||
|
||||
|
||||
# test_openai_embedding_timeouts()
|
||||
|
||||
|
||||
def test_openai_azure_embedding():
|
||||
try:
|
||||
api_key = os.environ['AZURE_API_KEY']
|
||||
api_base = os.environ['AZURE_API_BASE']
|
||||
api_version = os.environ['AZURE_API_VERSION']
|
||||
api_key = os.environ["AZURE_API_KEY"]
|
||||
api_base = os.environ["AZURE_API_BASE"]
|
||||
api_version = os.environ["AZURE_API_VERSION"]
|
||||
|
||||
os.environ['AZURE_API_VERSION'] = ""
|
||||
os.environ['AZURE_API_BASE'] = ""
|
||||
os.environ['AZURE_API_KEY'] = ""
|
||||
os.environ["AZURE_API_VERSION"] = ""
|
||||
os.environ["AZURE_API_BASE"] = ""
|
||||
os.environ["AZURE_API_KEY"] = ""
|
||||
|
||||
response = embedding(
|
||||
model="azure/azure-embedding-model",
|
||||
|
@ -114,137 +136,179 @@ def test_openai_azure_embedding():
|
|||
)
|
||||
print(response)
|
||||
|
||||
|
||||
os.environ['AZURE_API_VERSION'] = api_version
|
||||
os.environ['AZURE_API_BASE'] = api_base
|
||||
os.environ['AZURE_API_KEY'] = api_key
|
||||
os.environ["AZURE_API_VERSION"] = api_version
|
||||
os.environ["AZURE_API_BASE"] = api_base
|
||||
os.environ["AZURE_API_KEY"] = api_key
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_openai_azure_embedding()
|
||||
|
||||
# test_openai_embedding()
|
||||
|
||||
|
||||
def test_cohere_embedding():
|
||||
try:
|
||||
# litellm.set_verbose=True
|
||||
response = embedding(
|
||||
model="embed-english-v2.0", input=["good morning from litellm", "this is another item"]
|
||||
model="embed-english-v2.0",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
)
|
||||
print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_cohere_embedding()
|
||||
|
||||
|
||||
def test_cohere_embedding3():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
litellm.set_verbose = True
|
||||
response = embedding(
|
||||
model="embed-english-v3.0",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
model="embed-english-v3.0",
|
||||
input=["good morning from litellm", "this is another item"],
|
||||
)
|
||||
print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_cohere_embedding3()
|
||||
|
||||
|
||||
def test_bedrock_embedding_titan():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
litellm.set_verbose = True
|
||||
response = embedding(
|
||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
||||
"lets test a second string for good measure"]
|
||||
model="amazon.titan-embed-text-v1",
|
||||
input=[
|
||||
"good morning from litellm, attempting to embed data",
|
||||
"lets test a second string for good measure",
|
||||
],
|
||||
)
|
||||
print(f"response:", response)
|
||||
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
|
||||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
assert isinstance(
|
||||
response["data"][0]["embedding"], list
|
||||
), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
assert all(
|
||||
isinstance(x, float) for x in response["data"][0]["embedding"]
|
||||
), "Expected response to be a list of floats"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
test_bedrock_embedding_titan()
|
||||
|
||||
|
||||
def test_bedrock_embedding_cohere():
|
||||
try:
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose = False
|
||||
response = embedding(
|
||||
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
|
||||
aws_region_name="os.environ/AWS_REGION_NAME_2"
|
||||
model="cohere.embed-multilingual-v3",
|
||||
input=[
|
||||
"good morning from litellm, attempting to embed data",
|
||||
"lets test a second string for good measure",
|
||||
],
|
||||
aws_region_name="os.environ/AWS_REGION_NAME_2",
|
||||
)
|
||||
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
|
||||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
assert isinstance(
|
||||
response["data"][0]["embedding"], list
|
||||
), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
assert all(
|
||||
isinstance(x, float) for x in response["data"][0]["embedding"]
|
||||
), "Expected response to be a list of floats"
|
||||
# print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_bedrock_embedding_cohere()


# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
    try:
        # huggingface/microsoft/codebert-base
        # huggingface/facebook/bart-large
        response = embedding(
            model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
            model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response:", response)
    except Exception as e:
        # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
        pass


# test_hf_embedding()


# test async embeddings
def test_aembedding():
    try:
        import asyncio

        async def embedding_call():
            try:
                response = await litellm.aembedding(
                    model="text-embedding-ada-002",
                    input=["good morning from litellm", "this is another item"]
                    model="text-embedding-ada-002",
                    input=["good morning from litellm", "this is another item"],
                )
                print(response)
            except Exception as e:
                pytest.fail(f"Error occurred: {e}")

        asyncio.run(embedding_call())
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_aembedding()


def test_aembedding_azure():
    try:
        import asyncio

        async def embedding_call():
            try:
                response = await litellm.aembedding(
                    model="azure/azure-embedding-model",
                    input=["good morning from litellm", "this is another item"]
                    model="azure/azure-embedding-model",
                    input=["good morning from litellm", "this is another item"],
                )
                print(response)
            except Exception as e:
                pytest.fail(f"Error occurred: {e}")

        asyncio.run(embedding_call())
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_aembedding_azure()
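
A minimal sketch of the async pattern used above: litellm.aembedding is the awaitable counterpart of embedding(). The OpenAI model name is simply the one from the test and assumes OPENAI_API_KEY is set; treat this as an illustration, not a drop-in test:

    import asyncio
    import litellm

    async def embed_async():
        # awaitable counterpart of litellm.embedding()
        response = await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["good morning from litellm", "this is another item"],
        )
        return [item["embedding"] for item in response["data"]]  # one vector per input string

    vectors = asyncio.run(embed_async())
    print(len(vectors))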


def test_sagemaker_embeddings():
    try:
        response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])

def test_sagemaker_embeddings():
    try:
        response = litellm.embedding(
            model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response: {response}")
    except Exception as e:
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_sagemaker_embeddings()
# def local_proxy_embeddings():
#     litellm.set_verbose=True
#     response = embedding(
#         model="openai/custom_embedding",
#         model="openai/custom_embedding",
#         input=["good morning from litellm"],
#         api_base="http://0.0.0.0:8000/"
#     )

@@ -11,17 +11,18 @@ import litellm
from litellm import (
    embedding,
    completion,
    # AuthenticationError,
    # AuthenticationError,
    ContextWindowExceededError,
    # RateLimitError,
    # ServiceUnavailableError,
    # OpenAIError,
    # RateLimitError,
    # ServiceUnavailableError,
    # OpenAIError,
)
from concurrent.futures import ThreadPoolExecutor
import pytest

litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
litellm.num_retries=0
litellm.num_retries = 0

# litellm.failure_callback = ["sentry"]
#### What this tests ####

@@ -36,7 +37,8 @@ litellm.num_retries=0

models = ["command-nightly"]

# Test 1: Context Window Errors

# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
    print("Testing context window error")

@@ -52,17 +54,27 @@ def test_context_window(model):
        print(f"Worked!")
    except RateLimitError:
        print("RateLimited!")
    except Exception as e:
    except Exception as e:
        print(f"{e}")
        pytest.fail(f"An error occcurred - {e}")


@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
    ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"}
    ctx_window_fallback_dict = {
        "command-nightly": "claude-2",
        "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
        "azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
    }
    sample_text = "how does a court case get to the Supreme Court?" * 1000
    messages = [{"content": sample_text, "role": "user"}]

    completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
    completion(
        model=model,
        messages=messages,
        context_window_fallback_dict=ctx_window_fallback_dict,
    )


# for model in litellm.models_by_provider["bedrock"]:
#     test_context_window(model=model)
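
A rough sketch of the behaviour test_context_window_with_fallbacks targets: when the prompt exceeds the primary model's context window, litellm retries on the model mapped in context_window_fallback_dict. The model names are the ones from the test and the relevant API keys are assumed to be configured:

    from litellm import completion

    ctx_window_fallback_dict = {"command-nightly": "claude-2"}
    sample_text = "how does a court case get to the Supreme Court?" * 1000  # deliberately oversized prompt

    response = completion(
        model="command-nightly",
        messages=[{"content": sample_text, "role": "user"}],
        context_window_fallback_dict=ctx_window_fallback_dict,  # retried on claude-2 if the context window is exceeded
    )
    print(response["choices"][0]["message"]["content"][:100])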

@@ -98,7 +110,9 @@ def invalid_auth(model):  # set the model key to an invalid key, depending on th
        os.environ["AI21_API_KEY"] = "bad-key"
    elif "togethercomputer" in model:
        temporary_key = os.environ["TOGETHERAI_API_KEY"]
        os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
        os.environ[
            "TOGETHERAI_API_KEY"
        ] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
    elif model in litellm.openrouter_models:
        temporary_key = os.environ["OPENROUTER_API_KEY"]
        os.environ["OPENROUTER_API_KEY"] = "bad-key"

@@ -115,9 +129,7 @@ def invalid_auth(model):  # set the model key to an invalid key, depending on th
        temporary_key = os.environ["REPLICATE_API_KEY"]
        os.environ["REPLICATE_API_KEY"] = "bad-key"
        print(f"model: {model}")
        response = completion(
            model=model, messages=messages
        )
        response = completion(model=model, messages=messages)
        print(f"response: {response}")
    except AuthenticationError as e:
        print(f"AuthenticationError Caught Exception - {str(e)}")

@@ -148,23 +160,25 @@ def invalid_auth(model):  # set the model key to an invalid key, depending on th
        os.environ["REPLICATE_API_KEY"] = temporary_key
    elif "j2" in model:
        os.environ["AI21_API_KEY"] = temporary_key
    elif ("togethercomputer" in model):
    elif "togethercomputer" in model:
        os.environ["TOGETHERAI_API_KEY"] = temporary_key
    elif model in litellm.aleph_alpha_models:
        os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
    elif model in litellm.nlp_cloud_models:
        os.environ["NLP_CLOUD_API_KEY"] = temporary_key
    elif "bedrock" in model:
    elif "bedrock" in model:
        os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
        os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
        os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
    return


# for model in litellm.models_by_provider["bedrock"]:
#     invalid_auth(model=model)
# invalid_auth(model="command-nightly")

# Test 3: Invalid Request Error

# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):
    messages = [{"content": "hey, how's it going?", "role": "user"}]

@@ -173,23 +187,18 @@ def test_invalid_request_error(model):
        completion(model=model, messages=messages, max_tokens="hello world")


def test_completion_azure_exception():
    try:
        import openai

        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose=True
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
        response = completion(
            model="azure/chatgpt-v-2",
            messages=[
                {
                    "role": "user",
                    "content": "hello"
                }
            ],
            messages=[{"role": "user", "content": "hello"}],
        )
        os.environ["AZURE_API_KEY"] = old_azure_key
        print(f"response: {response}")

@@ -199,25 +208,24 @@ def test_completion_azure_exception():
        print("good job got the correct error for azure when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_azure_exception()
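
The pattern being asserted above, in isolation: with an invalid key, litellm is expected to surface the provider failure as an OpenAI-compatible AuthenticationError, so one except clause works across providers. The Azure deployment name is just the one used in these tests, and AZURE_API_BASE / AZURE_API_VERSION are assumed to already be configured; a sketch, not a drop-in test:

    import os
    import openai
    import litellm

    os.environ["AZURE_API_KEY"] = "bad-key"  # deliberately invalid key
    try:
        litellm.completion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "hello"}],
        )
    except openai.AuthenticationError as e:
        # litellm's auth errors are OpenAI-compatible, which is what these tests rely on
        print(f"caught mapped AuthenticationError: {e}")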


async def asynctest_completion_azure_exception():
    try:
        import openai
        import litellm

        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose=True
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[
                {
                    "role": "user",
                    "content": "hello"
                }
            ],
            messages=[{"role": "user", "content": "hello"}],
        )
        print(f"response: {response}")
        print(response)

@@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception():
        print("Got wrong exception")
        print("exception", e)
        pytest.fail(f"Error occurred: {e}")


# import asyncio
# asyncio.run(
#     asynctest_completion_azure_exception()

@@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model():
    try:
        import openai
        import litellm, asyncio

        print("azure exception bad model\n\n")
        litellm.set_verbose=True
        litellm.set_verbose = True

        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="openai/gpt-6",
                messages=[
                    {
                        "role": "user",
                        "content": "hello"
                    }
                ],
                messages=[{"role": "user", "content": "hello"}],
            )

        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")

@@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model():
        assert isinstance(e, openai.BadRequestError)
        pytest.fail(f"Error occurred: {e}")

# asynctest_completion_openai_exception_bad_model()

# asynctest_completion_openai_exception_bad_model()
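
And the bad-model counterpart in isolation: requesting a model that does not exist is expected to surface as openai.NotFoundError. "openai/gpt-6" is the intentionally nonexistent name from the test, and a valid OPENAI_API_KEY is assumed; illustrative only:

    import asyncio
    import openai
    import litellm

    async def call_missing_model():
        try:
            await litellm.acompletion(
                model="openai/gpt-6",  # intentionally nonexistent model
                messages=[{"role": "user", "content": "hello"}],
            )
        except openai.NotFoundError:
            print("got NotFoundError for a model that does not exist")

    asyncio.run(call_missing_model())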


def asynctest_completion_azure_exception_bad_model():
    try:
        import openai
        import litellm, asyncio

        print("azure exception bad model\n\n")
        litellm.set_verbose=True
        litellm.set_verbose = True

        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="azure/gpt-12",
                messages=[
                    {
                        "role": "user",
                        "content": "hello"
                    }
                ],
                messages=[{"role": "user", "content": "hello"}],
            )

        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")

@@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model():
        print("Raised wrong type of exception", type(e))
        pytest.fail(f"Error occurred: {e}")


# asynctest_completion_azure_exception_bad_model()


def test_completion_openai_exception():
    # test if openai:gpt raises openai.AuthenticationError
    try:
        import openai

        print("openai gpt-3.5 test\n\n")
        litellm.set_verbose=True
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["OPENAI_API_KEY"]
        os.environ["OPENAI_API_KEY"] = "good morning"
        response = completion(
            model="gpt-4",
            messages=[
                {
                    "role": "user",
                    "content": "hello"
                }
            ],
            messages=[{"role": "user", "content": "hello"}],
        )
        print(f"response: {response}")
        print(response)

@@ -317,25 +321,24 @@ def test_completion_openai_exception():
        print("OpenAI: good job got the correct error for openai when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_openai_exception()


def test_completion_mistral_exception():
    # test if mistral/mistral-tiny raises openai.AuthenticationError
    try:
        import openai

        print("Testing mistral ai exception mapping")
        litellm.set_verbose=True
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["MISTRAL_API_KEY"]
        os.environ["MISTRAL_API_KEY"] = "good morning"
        response = completion(
            model="mistral/mistral-tiny",
            messages=[
                {
                    "role": "user",
                    "content": "hello"
                }
            ],
            messages=[{"role": "user", "content": "hello"}],
        )
        print(f"response: {response}")
        print(response)

@@ -344,11 +347,11 @@ def test_completion_mistral_exception():
        print("good job got the correct error for openai when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_mistral_exception()
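
Taken together, these exception tests rely on litellm re-raising provider failures as OpenAI exception types. A hedged sketch of the uniform handling that enables, with illustrative model names and the assumption that any real keys are configured:

    import openai
    import litellm

    def safe_completion(model, messages):
        # One handler covers OpenAI, Azure, Mistral, etc., because litellm
        # maps provider errors onto OpenAI-compatible exception classes.
        try:
            return litellm.completion(model=model, messages=messages)
        except openai.AuthenticationError:
            return None  # bad or missing API key
        except openai.NotFoundError:
            return None  # unknown model or deployment

    print(safe_completion("mistral/mistral-tiny", [{"role": "user", "content": "hello"}]))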


# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):

@@ -387,4 +390,4 @@ def test_completion_mistral_exception():
# counts[result] += 1

# accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}")
# print(f"accuracy_score: {accuracy_score}")