From 4905929de3a2f75a5bae8fb5bf6b0ceb4853ea6f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 25 Dec 2023 14:10:38 +0530 Subject: [PATCH] refactor: add black formatting --- .pre-commit-config.yaml | 6 +- cookbook/benchmark/benchmark.py | 58 +- .../auto_evals.py | 21 +- cookbook/codellama-server/main.py | 55 +- cookbook/community-resources/get_hf_models.py | 57 +- cookbook/litellm-ollama-docker-image/test.py | 23 +- cookbook/litellm_router/load_test_proxy.py | 84 +- cookbook/litellm_router/load_test_queuing.py | 91 +- cookbook/litellm_router/load_test_router.py | 79 +- litellm/__init__.py | 261 +- litellm/_logging.py | 5 +- litellm/_redis.py | 44 +- litellm/budget_manager.py | 168 +- litellm/caching.py | 138 +- litellm/deprecated_litellm_server/__init__.py | 2 +- litellm/deprecated_litellm_server/main.py | 42 +- .../deprecated_litellm_server/server_utils.py | 35 +- litellm/exceptions.py | 84 +- litellm/integrations/custom_logger.py | 79 +- litellm/integrations/dynamodb.py | 26 +- litellm/integrations/langfuse.py | 5 +- litellm/integrations/langsmith.py | 21 +- litellm/integrations/litedebugger.py | 61 +- litellm/integrations/prompt_layer.py | 16 +- litellm/integrations/supabase.py | 1 + litellm/integrations/traceloop.py | 18 +- litellm/integrations/weights_biases.py | 40 +- litellm/llms/ai21.py | 156 +- litellm/llms/aleph_alpha.py | 238 +- litellm/llms/anthropic.py | 124 +- litellm/llms/azure.py | 647 +-- litellm/llms/base.py | 34 +- litellm/llms/baseten.py | 59 +- litellm/llms/bedrock.py | 600 +-- litellm/llms/cohere.py | 220 +- litellm/llms/custom_httpx/azure_dall_e_2.py | 38 +- litellm/llms/gemini.py | 160 +- litellm/llms/huggingface_restapi.py | 573 ++- litellm/llms/maritalk.py | 111 +- litellm/llms/nlp_cloud.py | 143 +- litellm/llms/ollama.py | 322 +- litellm/llms/oobabooga.py | 63 +- litellm/llms/openai.py | 1021 +++-- litellm/llms/openrouter.py | 41 +- litellm/llms/palm.py | 145 +- litellm/llms/petals.py | 171 +- litellm/llms/prompt_templates/factory.py | 403 +- litellm/llms/replicate.py | 236 +- litellm/llms/sagemaker.py | 264 +- litellm/llms/together_ai.py | 147 +- litellm/llms/vertex_ai.py | 514 ++- litellm/llms/vllm.py | 61 +- litellm/main.py | 1733 +++++--- litellm/proxy/__init__.py | 2 +- .../proxy/_experimental/post_call_rules.py | 6 +- litellm/proxy/_types.py | 116 +- litellm/proxy/custom_auth.py | 12 +- litellm/proxy/custom_callbacks.py | 32 +- litellm/proxy/health_check.py | 47 +- litellm/proxy/hooks/max_budget_limiter.py | 35 +- .../proxy/hooks/parallel_request_limiter.py | 78 +- litellm/proxy/proxy_cli.py | 358 +- litellm/proxy/proxy_server.py | 1390 ++++-- litellm/proxy/queue/celery_app.py | 132 +- litellm/proxy/queue/celery_worker.py | 7 +- litellm/proxy/queue/rq_worker.py | 56 +- .../tests/bursty_load_test_completion.py | 14 +- litellm/proxy/tests/load_test_completion.py | 9 +- litellm/proxy/tests/load_test_embedding.py | 14 +- .../proxy/tests/load_test_embedding_100.py | 18 +- .../proxy/tests/load_test_embedding_proxy.py | 14 +- litellm/proxy/tests/load_test_q.py | 69 +- litellm/proxy/tests/test_async.py | 7 +- litellm/proxy/tests/test_langchain_request.py | 4 - litellm/proxy/tests/test_q.py | 53 +- litellm/proxy/utils.py | 329 +- litellm/router.py | 1241 ++++-- litellm/router_strategy/least_busy.py | 99 +- litellm/tests/conftest.py | 17 +- litellm/tests/test_acooldowns_router.py | 136 +- litellm/tests/test_add_function_to_prompt.py | 90 +- .../tests/test_amazing_vertex_completion.py | 173 +- litellm/tests/test_async_fn.py | 81 +- 
litellm/tests/test_azure_perf.py | 151 +- litellm/tests/test_bad_params.py | 92 +- litellm/tests/test_batch_completions.py | 36 +- litellm/tests/test_budget_manager.py | 26 +- litellm/tests/test_caching.py | 469 +- litellm/tests/test_caching_ssl.py | 83 +- litellm/tests/test_class.py | 16 +- litellm/tests/test_completion.py | 829 ++-- litellm/tests/test_completion_with_retries.py | 22 +- litellm/tests/test_config.py | 77 +- litellm/tests/test_configs/custom_auth.py | 12 +- .../tests/test_configs/custom_callbacks.py | 66 +- litellm/tests/test_custom_callback_input.py | 786 ++-- litellm/tests/test_custom_callback_router.py | 683 +-- litellm/tests/test_dynamodb_logs.py | 52 +- litellm/tests/test_embedding.py | 170 +- litellm/tests/test_exceptions.py | 131 +- litellm/tests/test_function_calling.py | 65 +- litellm/tests/test_get_llm_provider.py | 4 +- litellm/tests/test_get_model_cost_map.py | 80 +- litellm/tests/test_get_model_file.py | 2 +- litellm/tests/test_hf_prompt_templates.py | 39 +- litellm/tests/test_image_generation.py | 54 +- litellm/tests/test_langchain_ChatLiteLLM.py | 1 - litellm/tests/test_langsmith.py | 68 +- litellm/tests/test_least_busy_routing.py | 16 +- litellm/tests/test_litellm_max_budget.py | 10 +- litellm/tests/test_loadtest_router.py | 18 +- litellm/tests/test_logging.py | 36 +- litellm/tests/test_longer_context_fallback.py | 3 +- litellm/tests/test_mock_request.py | 15 +- litellm/tests/test_model_alias_map.py | 8 +- litellm/tests/test_multiple_deployments.py | 56 +- litellm/tests/test_ollama.py | 21 +- litellm/tests/test_ollama_local.py | 65 +- litellm/tests/test_optional_params.py | 25 +- litellm/tests/test_profiling_router.py | 20 +- litellm/tests/test_prompt_factory.py | 11 +- litellm/tests/test_promptlayer_integration.py | 21 +- .../tests/test_provider_specific_config.py | 316 +- litellm/tests/test_proxy_custom_auth.py | 22 +- litellm/tests/test_proxy_custom_logger.py | 164 +- litellm/tests/test_proxy_exception_mapping.py | 72 +- litellm/tests/test_proxy_gunicorn.py | 12 +- litellm/tests/test_proxy_server.py | 150 +- litellm/tests/test_proxy_server_caching.py | 4 +- litellm/tests/test_proxy_server_cost.py | 16 +- litellm/tests/test_proxy_server_keys.py | 134 +- litellm/tests/test_proxy_server_langfuse.py | 38 +- litellm/tests/test_proxy_server_spend.py | 2 +- litellm/tests/test_register_model.py | 54 +- litellm/tests/test_router.py | 1543 +++---- litellm/tests/test_router_caching.py | 318 +- litellm/tests/test_router_fallbacks.py | 454 +- litellm/tests/test_router_get_deployments.py | 606 +-- litellm/tests/test_router_init.py | 62 +- litellm/tests/test_rules.py | 43 +- litellm/tests/test_stream_chunk_builder.py | 167 +- litellm/tests/test_streaming.py | 582 +-- litellm/tests/test_supabase_integration.py | 28 +- litellm/tests/test_text_completion.py | 2757 +++++++++++- litellm/tests/test_timeout.py | 41 +- litellm/tests/test_together_ai.py | 4 +- litellm/tests/test_token_counter.py | 90 +- litellm/tests/test_traceloop.py | 14 +- litellm/tests/test_utils.py | 235 +- litellm/tests/test_validate_environment.py | 2 +- litellm/tests/test_wandb.py | 58 +- litellm/timeout.py | 8 +- litellm/utils.py | 3809 +++++++++++------ ui/admin.py | 40 +- ui/app.py | 14 +- ui/auth.py | 11 +- 156 files changed, 19723 insertions(+), 10869 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88f268157..2a1336933 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,13 @@ repos: +- repo: https://github.com/psf/black + rev: stable + 
hooks: + - id: black - repo: https://github.com/pycqa/flake8 rev: 3.8.4 # The version of flake8 to use hooks: - id: flake8 - exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/ + exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/ additional_dependencies: [flake8-print] files: litellm/.*\.py - repo: local diff --git a/cookbook/benchmark/benchmark.py b/cookbook/benchmark/benchmark.py index 882703a87..b38d185a1 100644 --- a/cookbook/benchmark/benchmark.py +++ b/cookbook/benchmark/benchmark.py @@ -9,33 +9,37 @@ import os # Define the list of models to benchmark # select any LLM listed here: https://docs.litellm.ai/docs/providers -models = ['gpt-3.5-turbo', 'claude-2'] +models = ["gpt-3.5-turbo", "claude-2"] # Enter LLM API keys # https://docs.litellm.ai/docs/providers -os.environ['OPENAI_API_KEY'] = "" -os.environ['ANTHROPIC_API_KEY'] = "" +os.environ["OPENAI_API_KEY"] = "" +os.environ["ANTHROPIC_API_KEY"] = "" # List of questions to benchmark (replace with your questions) -questions = [ - "When will BerriAI IPO?", - "When will LiteLLM hit $100M ARR?" -] +questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"] -# Enter your system prompt here +# Enter your system prompt here system_prompt = """ You are LiteLLMs helpful assistant """ + @click.command() -@click.option('--system-prompt', default="You are a helpful assistant that can answer questions.", help="System prompt for the conversation.") +@click.option( + "--system-prompt", + default="You are a helpful assistant that can answer questions.", + help="System prompt for the conversation.", +) def main(system_prompt): for question in questions: data = [] # Data for the current question with tqdm(total=len(models)) as pbar: for model in models: - colored_description = colored(f"Running question: {question} for model: {model}", 'green') + colored_description = colored( + f"Running question: {question} for model: {model}", "green" + ) pbar.set_description(colored_description) start_time = time.time() @@ -44,35 +48,43 @@ def main(system_prompt): max_tokens=500, messages=[ {"role": "system", "content": system_prompt}, - {"role": "user", "content": question} + {"role": "user", "content": question}, ], ) end = time.time() total_time = end - start_time cost = completion_cost(completion_response=response) - raw_response = response['choices'][0]['message']['content'] + raw_response = response["choices"][0]["message"]["content"] - data.append({ - 'Model': colored(model, 'light_blue'), - 'Response': raw_response, # Colorize the response - 'ResponseTime': colored(f"{total_time:.2f} seconds", "red"), - 'Cost': colored(f"${cost:.6f}", 'green'), # Colorize the cost - }) + data.append( + { + "Model": colored(model, "light_blue"), + "Response": raw_response, # Colorize the response + "ResponseTime": colored(f"{total_time:.2f} seconds", "red"), + "Cost": colored(f"${cost:.6f}", "green"), # Colorize the cost + } + ) pbar.update(1) # Separate headers from the data - headers = ['Model', 'Response', 'Response Time (seconds)', 'Cost ($)'] + headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"] colwidths = [15, 80, 15, 10] # Create a nicely formatted table for the current question - table = tabulate([list(d.values()) for d in data], headers, tablefmt="grid", maxcolwidths=colwidths) - + table = tabulate( + [list(d.values()) for d in data], + headers, + tablefmt="grid", + maxcolwidths=colwidths, + ) + # 
Print the table for the current question - colored_question = colored(question, 'green') + colored_question = colored(question, "green") click.echo(f"\nBenchmark Results for '{colored_question}':") click.echo(table) # Display the formatted table -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/cookbook/benchmark/eval_suites_mlflow_autoevals/auto_evals.py b/cookbook/benchmark/eval_suites_mlflow_autoevals/auto_evals.py index fd76343c6..94682793a 100644 --- a/cookbook/benchmark/eval_suites_mlflow_autoevals/auto_evals.py +++ b/cookbook/benchmark/eval_suites_mlflow_autoevals/auto_evals.py @@ -1,25 +1,22 @@ import sys, os import traceback from dotenv import load_dotenv + load_dotenv() import litellm from litellm import embedding, completion, completion_cost from autoevals.llm import * + ################### import litellm # litellm completion call question = "which country has the highest population" response = litellm.completion( - model = "gpt-3.5-turbo", - messages = [ - { - "role": "user", - "content": question - } - ], + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": question}], ) print(response) # use the auto eval Factuality() evaluator @@ -27,9 +24,11 @@ print(response) print("calling evaluator") evaluator = Factuality() result = evaluator( - output=response.choices[0]["message"]["content"], # response from litellm.completion() - expected="India", # expected output - input=question # question passed to litellm.completion + output=response.choices[0]["message"][ + "content" + ], # response from litellm.completion() + expected="India", # expected output + input=question, # question passed to litellm.completion ) -print(result) \ No newline at end of file +print(result) diff --git a/cookbook/codellama-server/main.py b/cookbook/codellama-server/main.py index 51627ceca..a31220338 100644 --- a/cookbook/codellama-server/main.py +++ b/cookbook/codellama-server/main.py @@ -4,9 +4,10 @@ from flask_cors import CORS import traceback import litellm from util import handle_error -from litellm import completion -import os, dotenv, time +from litellm import completion +import os, dotenv, time import json + dotenv.load_dotenv() # TODO: set your keys in .env or here: @@ -19,57 +20,72 @@ verbose = True # litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. 
- to learn more: https://docs.litellm.ai/docs/caching/ ######### PROMPT LOGGING ########## -os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/ +os.environ[ + "PROMPTLAYER_API_KEY" +] = "" # set your promptlayer key here - https://promptlayer.com/ # set callbacks litellm.success_callback = ["promptlayer"] ############ HELPER FUNCTIONS ################################### + def print_verbose(print_statement): if verbose: print(print_statement) + app = Flask(__name__) CORS(app) -@app.route('/') + +@app.route("/") def index(): - return 'received!', 200 + return "received!", 200 + def data_generator(response): for chunk in response: yield f"data: {json.dumps(chunk)}\n\n" -@app.route('/chat/completions', methods=["POST"]) + +@app.route("/chat/completions", methods=["POST"]) def api_completion(): data = request.json - start_time = time.time() - if data.get('stream') == "True": - data['stream'] = True # convert to boolean + start_time = time.time() + if data.get("stream") == "True": + data["stream"] = True # convert to boolean try: if "prompt" not in data: raise ValueError("data needs to have prompt") - data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct + data[ + "model" + ] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct # COMPLETION CALL system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that." - messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": data.pop("prompt")}, + ] data["messages"] = messages print(f"data: {data}") response = completion(**data) ## LOG SUCCESS - end_time = time.time() - if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses - return Response(data_generator(response), mimetype='text/event-stream') + end_time = time.time() + if ( + "stream" in data and data["stream"] == True + ): # use generate_responses to stream responses + return Response(data_generator(response), mimetype="text/event-stream") except Exception as e: # call handle_error function print_verbose(f"Got Error api_completion(): {traceback.format_exc()}") ## LOG FAILURE - end_time = time.time() + end_time = time.time() traceback_exception = traceback.format_exc() return handle_error(data=data) return response -@app.route('/get_models', methods=["POST"]) + +@app.route("/get_models", methods=["POST"]) def get_models(): try: return litellm.model_list @@ -78,7 +94,8 @@ def get_models(): response = {"error": str(e)} return response, 200 -if __name__ == "__main__": - from waitress import serve - serve(app, host="0.0.0.0", port=4000, threads=500) +if __name__ == "__main__": + from waitress import serve + + serve(app, host="0.0.0.0", port=4000, threads=500) diff --git a/cookbook/community-resources/get_hf_models.py b/cookbook/community-resources/get_hf_models.py index 6c245fa61..2d8972791 100644 --- a/cookbook/community-resources/get_hf_models.py +++ b/cookbook/community-resources/get_hf_models.py @@ -3,27 +3,28 @@ from urllib.parse import urlparse, parse_qs def get_next_url(response): - """ + """ Function to get 'next' url from Link header :param response: response from 
requests :return: next url or None """ - if 'link' not in response.headers: - return None - headers = response.headers + if "link" not in response.headers: + return None + headers = response.headers - next_url = headers['Link'] - print(next_url) - start_index = next_url.find("<") - end_index = next_url.find(">") + next_url = headers["Link"] + print(next_url) + start_index = next_url.find("<") + end_index = next_url.find(">") + + return next_url[1:end_index] - return next_url[1:end_index] def get_models(url): """ - Function to retrieve all models from paginated endpoint - :param url: base url to make GET request - :return: list of all models + Function to retrieve all models from paginated endpoint + :param url: base url to make GET request + :return: list of all models """ models = [] while url: @@ -36,19 +37,21 @@ def get_models(url): models.extend(payload) return models + def get_cleaned_models(models): """ - Function to clean retrieved models - :param models: list of retrieved models - :return: list of cleaned models + Function to clean retrieved models + :param models: list of retrieved models + :return: list of cleaned models """ cleaned_models = [] for model in models: cleaned_models.append(model["id"]) return cleaned_models + # Get text-generation models -url = 'https://huggingface.co/api/models?filter=text-generation-inference' +url = "https://huggingface.co/api/models?filter=text-generation-inference" text_generation_models = get_models(url) cleaned_text_generation_models = get_cleaned_models(text_generation_models) @@ -56,7 +59,7 @@ print(cleaned_text_generation_models) # Get conversational models -url = 'https://huggingface.co/api/models?filter=conversational' +url = "https://huggingface.co/api/models?filter=conversational" conversational_models = get_models(url) cleaned_conversational_models = get_cleaned_models(conversational_models) @@ -65,19 +68,23 @@ print(cleaned_conversational_models) def write_to_txt(cleaned_models, filename): """ - Function to write the contents of a list to a text file - :param cleaned_models: list of cleaned models - :param filename: name of the text file + Function to write the contents of a list to a text file + :param cleaned_models: list of cleaned models + :param filename: name of the text file """ - with open(filename, 'w') as f: + with open(filename, "w") as f: for item in cleaned_models: f.write("%s\n" % item) # Write contents of cleaned_text_generation_models to text_generation_models.txt -write_to_txt(cleaned_text_generation_models, 'huggingface_llms_metadata/hf_text_generation_models.txt') +write_to_txt( + cleaned_text_generation_models, + "huggingface_llms_metadata/hf_text_generation_models.txt", +) # Write contents of cleaned_conversational_models to conversational_models.txt -write_to_txt(cleaned_conversational_models, 'huggingface_llms_metadata/hf_conversational_models.txt') - - +write_to_txt( + cleaned_conversational_models, + "huggingface_llms_metadata/hf_conversational_models.txt", +) diff --git a/cookbook/litellm-ollama-docker-image/test.py b/cookbook/litellm-ollama-docker-image/test.py index d3fb04f16..977bd3699 100644 --- a/cookbook/litellm-ollama-docker-image/test.py +++ b/cookbook/litellm-ollama-docker-image/test.py @@ -1,4 +1,3 @@ - import openai api_base = f"http://0.0.0.0:8000" @@ -8,29 +7,29 @@ openai.api_key = "temp-key" print(openai.api_base) -print(f'LiteLLM: response from proxy with streaming') +print(f"LiteLLM: response from proxy with streaming") response = openai.ChatCompletion.create( - model="ollama/llama2", - 
messages = [ + model="ollama/llama2", + messages=[ { "role": "user", - "content": "this is a test request, acknowledge that you got it" + "content": "this is a test request, acknowledge that you got it", } ], - stream=True + stream=True, ) for chunk in response: - print(f'LiteLLM: streaming response from proxy {chunk}') + print(f"LiteLLM: streaming response from proxy {chunk}") response = openai.ChatCompletion.create( - model="ollama/llama2", - messages = [ + model="ollama/llama2", + messages=[ { "role": "user", - "content": "this is a test request, acknowledge that you got it" + "content": "this is a test request, acknowledge that you got it", } - ] + ], ) -print(f'LiteLLM: response from proxy {response}') +print(f"LiteLLM: response from proxy {response}") diff --git a/cookbook/litellm_router/load_test_proxy.py b/cookbook/litellm_router/load_test_proxy.py index b4f708b47..adba968ba 100644 --- a/cookbook/litellm_router/load_test_proxy.py +++ b/cookbook/litellm_router/load_test_proxy.py @@ -12,42 +12,51 @@ import pytest from litellm import Router import litellm -litellm.set_verbose=False + +litellm.set_verbose = False os.environ.pop("AZURE_AD_TOKEN") -model_list = [{ # list of model deployments - "model_name": "gpt-3.5-turbo", # model alias - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", # actual model name - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - } -}] +model_list = [ + { # list of model deployments + "model_name": "gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", # actual model name + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, +] router = Router(model_list=model_list) -file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"] +file_paths = [ + "test_questions/question1.txt", + "test_questions/question2.txt", + "test_questions/question3.txt", +] questions = [] for file_path in file_paths: try: print(file_path) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: content = file.read() questions.append(content) except FileNotFoundError as e: @@ -59,10 +68,9 @@ for file_path in file_paths: # print(q) - # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. -# Allow me to tune X concurrent calls.. 
Log question, output/exception, response time somewhere -# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions +# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere +# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions import concurrent.futures import random @@ -74,10 +82,18 @@ def make_openai_completion(question): try: start_time = time.time() import openai - client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000", + + client = openai.OpenAI( + api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000" + ) # base_url="http://0.0.0.0:8000", response = client.chat.completions.create( model="gpt-3.5-turbo", - messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}], + messages=[ + { + "role": "system", + "content": f"You are a helpful assistant. Answer this question{question}", + } + ], ) print(response) end_time = time.time() @@ -92,11 +108,10 @@ def make_openai_completion(question): except Exception as e: # Log exceptions for failed calls with open("error_log.txt", "a") as error_log_file: - error_log_file.write( - f"Question: {question[:100]}\nException: {str(e)}\n\n" - ) + error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n") return None + # Number of concurrent calls (you can adjust this) concurrent_calls = 100 @@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file: with open("error_log.txt", "r") as error_log_file: print("\nError Log:\n", error_log_file.read()) - diff --git a/cookbook/litellm_router/load_test_queuing.py b/cookbook/litellm_router/load_test_queuing.py index f3acb8f04..7c22f2f42 100644 --- a/cookbook/litellm_router/load_test_queuing.py +++ b/cookbook/litellm_router/load_test_queuing.py @@ -12,42 +12,51 @@ import pytest from litellm import Router import litellm -litellm.set_verbose=False + +litellm.set_verbose = False # os.environ.pop("AZURE_AD_TOKEN") -model_list = [{ # list of model deployments - "model_name": "gpt-3.5-turbo", # model alias - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", # actual model name - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - } -}] +model_list = [ + { # list of model deployments + "model_name": "gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", # actual model name + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + 
"api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, +] router = Router(model_list=model_list) -file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"] +file_paths = [ + "test_questions/question1.txt", + "test_questions/question2.txt", + "test_questions/question3.txt", +] questions = [] for file_path in file_paths: try: print(file_path) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: content = file.read() questions.append(content) except FileNotFoundError as e: @@ -59,10 +68,9 @@ for file_path in file_paths: # print(q) - # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. -# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere -# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions +# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere +# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions import concurrent.futures import random @@ -76,9 +84,12 @@ def make_openai_completion(question): import requests data = { - 'model': 'gpt-3.5-turbo', - 'messages': [ - {'role': 'system', 'content': f'You are a helpful assistant. Answer this question{question}'}, + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "system", + "content": f"You are a helpful assistant. Answer this question{question}", + }, ], } response = requests.post("http://0.0.0.0:8000/queue/request", json=data) @@ -89,8 +100,8 @@ def make_openai_completion(question): log_file.write( f"Question: {question[:100]}\nResponse ID: {response.get('id', 'N/A')} Url: {response.get('url', 'N/A')}\nTime: {end_time - start_time:.2f} seconds\n\n" ) - - # polling the url + + # polling the url while True: try: url = response["url"] @@ -107,7 +118,9 @@ def make_openai_completion(question): ) break - print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") + print( + f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}" + ) time.sleep(0.5) except Exception as e: print("got exception in polling", e) @@ -117,11 +130,10 @@ def make_openai_completion(question): except Exception as e: # Log exceptions for failed calls with open("error_log.txt", "a") as error_log_file: - error_log_file.write( - f"Question: {question[:100]}\nException: {str(e)}\n\n" - ) + error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n") return None + # Number of concurrent calls (you can adjust this) concurrent_calls = 10 @@ -142,7 +154,7 @@ successful_calls = 0 failed_calls = 0 for future in futures: - if future.done(): + if future.done(): if future.result() is not None: successful_calls += 1 else: @@ -152,4 +164,3 @@ print(f"Load test Summary:") print(f"Total Requests: {concurrent_calls}") print(f"Successful Calls: {successful_calls}") print(f"Failed Calls: {failed_calls}") - diff --git a/cookbook/litellm_router/load_test_router.py b/cookbook/litellm_router/load_test_router.py index a9568d14b..5eed3867d 100644 --- a/cookbook/litellm_router/load_test_router.py +++ b/cookbook/litellm_router/load_test_router.py @@ -12,42 +12,51 @@ import pytest from litellm import Router import litellm 
-litellm.set_verbose=False + +litellm.set_verbose = False os.environ.pop("AZURE_AD_TOKEN") -model_list = [{ # list of model deployments - "model_name": "gpt-3.5-turbo", # model alias - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", # actual model name - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - } -}, { - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - } -}] +model_list = [ + { # list of model deployments + "model_name": "gpt-3.5-turbo", # model alias + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", # actual model name + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, +] router = Router(model_list=model_list) -file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"] +file_paths = [ + "test_questions/question1.txt", + "test_questions/question2.txt", + "test_questions/question3.txt", +] questions = [] for file_path in file_paths: try: print(file_path) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: content = file.read() questions.append(content) except FileNotFoundError as e: @@ -59,10 +68,9 @@ for file_path in file_paths: # print(q) - # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. -# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere -# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions +# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere +# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions import concurrent.futures import random @@ -75,7 +83,12 @@ def make_openai_completion(question): start_time = time.time() response = router.completion( model="gpt-3.5-turbo", - messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}], + messages=[ + { + "role": "system", + "content": f"You are a helpful assistant. 
Answer this question{question}", + } + ], ) print(response) end_time = time.time() @@ -90,11 +103,10 @@ def make_openai_completion(question): except Exception as e: # Log exceptions for failed calls with open("error_log.txt", "a") as error_log_file: - error_log_file.write( - f"Question: {question[:100]}\nException: {str(e)}\n\n" - ) + error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n") return None + # Number of concurrent calls (you can adjust this) concurrent_calls = 150 @@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file: with open("error_log.txt", "r") as error_log_file: print("\nError Log:\n", error_log_file.read()) - diff --git a/litellm/__init__.py b/litellm/__init__.py index ce606c8fd..e1ce37dfa 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] callbacks: List[Callable] = [] -_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. -_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. +_async_input_callback: List[ + Callable +] = [] # internal variable - async custom callbacks are routed here. +_async_success_callback: List[ + Union[str, Callable] +] = [] # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[ + Callable +] = [] # internal variable - async custom callbacks are routed here. pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] email: Optional[ @@ -42,20 +48,88 @@ aleph_alpha_key: Optional[str] = None nlp_cloud_key: Optional[str] = None use_client: bool = False logging: bool = True -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +cache: Optional[ + Cache +] = None # cache object <- use this - https://docs.litellm.ai/docs/caching model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} -max_budget: float = 0.0 # set the max budget across all providers -_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] -_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", 
"input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"] -_current_cost = 0 # private variable, used if max budget is set +max_budget: float = 0.0 # set the max budget across all providers +_openai_completion_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", +] +_litellm_completion_params = [ + "metadata", + "acompletion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", +] +_current_cost = 0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None -model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' +model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" suppress_debug_info = False dynamodb_table_name: Optional[str] = None @@ -66,23 +140,35 @@ fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None allowed_fails: int = 0 ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[Any] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +secret_manager_client: Optional[ + Any +] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. 
############################################# + def get_model_cost_map(url: str): try: - with requests.get(url, timeout=5) as response: # set a 5 second timeout for the get request - response.raise_for_status() # Raise an exception if the request is unsuccessful + with requests.get( + url, timeout=5 + ) as response: # set a 5 second timeout for the get request + response.raise_for_status() # Raise an exception if the request is unsuccessful content = response.json() return content except Exception as e: import importlib.resources import json - with importlib.resources.open_text("litellm", "model_prices_and_context_window_backup.json") as f: + + with importlib.resources.open_text( + "litellm", "model_prices_and_context_window_backup.json" + ) as f: content = json.load(f) return content + + model_cost = get_model_cost_map(url=model_cost_map_url) -custom_prompt_dict:Dict[str, dict] = {} +custom_prompt_dict: Dict[str, dict] = {} + + ####### THREAD-SPECIFIC DATA ################### class MyLocal(threading.local): def __init__(self): @@ -123,56 +209,51 @@ bedrock_models: List = [] deepinfra_models: List = [] perplexity_models: List = [] for key, value in model_cost.items(): - if value.get('litellm_provider') == 'openai': + if value.get("litellm_provider") == "openai": open_ai_chat_completion_models.append(key) - elif value.get('litellm_provider') == 'text-completion-openai': + elif value.get("litellm_provider") == "text-completion-openai": open_ai_text_completion_models.append(key) - elif value.get('litellm_provider') == 'cohere': + elif value.get("litellm_provider") == "cohere": cohere_models.append(key) - elif value.get('litellm_provider') == 'anthropic': + elif value.get("litellm_provider") == "anthropic": anthropic_models.append(key) - elif value.get('litellm_provider') == 'openrouter': + elif value.get("litellm_provider") == "openrouter": openrouter_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-text-models': + elif value.get("litellm_provider") == "vertex_ai-text-models": vertex_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-text-models': + elif value.get("litellm_provider") == "vertex_ai-code-text-models": vertex_code_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-language-models': + elif value.get("litellm_provider") == "vertex_ai-language-models": vertex_language_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-vision-models': + elif value.get("litellm_provider") == "vertex_ai-vision-models": vertex_vision_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-chat-models': + elif value.get("litellm_provider") == "vertex_ai-chat-models": vertex_chat_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-chat-models': + elif value.get("litellm_provider") == "vertex_ai-code-chat-models": vertex_code_chat_models.append(key) - elif value.get('litellm_provider') == 'ai21': + elif value.get("litellm_provider") == "ai21": ai21_models.append(key) - elif value.get('litellm_provider') == 'nlp_cloud': + elif value.get("litellm_provider") == "nlp_cloud": nlp_cloud_models.append(key) - elif value.get('litellm_provider') == 'aleph_alpha': + elif value.get("litellm_provider") == "aleph_alpha": aleph_alpha_models.append(key) - elif value.get('litellm_provider') == 'bedrock': + elif value.get("litellm_provider") == "bedrock": bedrock_models.append(key) - elif value.get('litellm_provider') == 'deepinfra': + elif value.get("litellm_provider") == "deepinfra": 
deepinfra_models.append(key) - elif value.get('litellm_provider') == 'perplexity': + elif value.get("litellm_provider") == "perplexity": perplexity_models.append(key) # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary openai_compatible_endpoints: List = [ - "api.perplexity.ai", + "api.perplexity.ai", "api.endpoints.anyscale.com/v1", "api.deepinfra.com/v1/openai", - "api.mistral.ai/v1" + "api.mistral.ai/v1", ] # this is maintained for Exception Mapping -openai_compatible_providers: List = [ - "anyscale", - "mistral", - "deepinfra", - "perplexity" -] +openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"] # well supported replicate llms @@ -209,23 +290,18 @@ huggingface_models: List = [ together_ai_models: List = [ # llama llms - chat "togethercomputer/llama-2-70b-chat", - - # llama llms - language / instruct + # llama llms - language / instruct "togethercomputer/llama-2-70b", "togethercomputer/LLaMA-2-7B-32K", "togethercomputer/Llama-2-7B-32K-Instruct", "togethercomputer/llama-2-7b", - # falcon llms "togethercomputer/falcon-40b-instruct", "togethercomputer/falcon-7b-instruct", - # alpaca "togethercomputer/alpaca-7b", - # chat llms "HuggingFaceH4/starchat-alpha", - # code llms "togethercomputer/CodeLlama-34b", "togethercomputer/CodeLlama-34b-Instruct", @@ -234,29 +310,27 @@ together_ai_models: List = [ "NumbersStation/nsql-llama-2-7B", "WizardLM/WizardCoder-15B-V1.0", "WizardLM/WizardCoder-Python-34B-V1.0", - # language llms "NousResearch/Nous-Hermes-Llama2-13b", "Austism/chronos-hermes-13b", "upstage/SOLAR-0-70b-16bit", "WizardLM/WizardLM-70B-V1.0", - -] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...) +] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...) 
-baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML +baseten_models: List = [ + "qvv0xeq", + "q841o8w", + "31dxrj3", +] # FALCON 7B # WizardLM # Mosaic ML petals_models = [ "petals-team/StableBeluga2", ] -ollama_models = [ - "llama2" -] +ollama_models = ["llama2"] -maritalk_models = [ - "maritalk" -] +maritalk_models = ["maritalk"] model_list = ( open_ai_chat_completion_models @@ -308,7 +382,7 @@ provider_list: List = [ "anyscale", "mistral", "maritalk", - "custom", # custom apis + "custom", # custom apis ] models_by_provider: dict = { @@ -327,28 +401,28 @@ models_by_provider: dict = { "ollama": ollama_models, "deepinfra": deepinfra_models, "perplexity": perplexity_models, - "maritalk": maritalk_models + "maritalk": maritalk_models, } -# mapping for those models which have larger equivalents +# mapping for those models which have larger equivalents longer_context_model_fallback_dict: dict = { # openai chat completion models - "gpt-3.5-turbo": "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301", - "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613", - "gpt-4": "gpt-4-32k", - "gpt-4-0314": "gpt-4-32k-0314", - "gpt-4-0613": "gpt-4-32k-0613", - # anthropic - "claude-instant-1": "claude-2", + "gpt-3.5-turbo": "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301", + "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613", + "gpt-4": "gpt-4-32k", + "gpt-4-0314": "gpt-4-32k-0314", + "gpt-4-0613": "gpt-4-32k-0613", + # anthropic + "claude-instant-1": "claude-2", "claude-instant-1.2": "claude-2", # vertexai "chat-bison": "chat-bison-32k", "chat-bison@001": "chat-bison-32k", - "codechat-bison": "codechat-bison-32k", + "codechat-bison": "codechat-bison-32k", "codechat-bison@001": "codechat-bison-32k", - # openrouter - "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k", + # openrouter + "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k", "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2", } @@ -357,20 +431,23 @@ open_ai_embedding_models: List = ["text-embedding-ada-002"] cohere_embedding_models: List = [ "embed-english-v3.0", "embed-english-light-v3.0", - "embed-multilingual-v3.0", - "embed-english-v2.0", - "embed-english-light-v2.0", - "embed-multilingual-v2.0", + "embed-multilingual-v3.0", + "embed-english-v2.0", + "embed-english-light-v2.0", + "embed-multilingual-v2.0", +] +bedrock_embedding_models: List = [ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", ] -bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] -all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models +all_embedding_models = ( + open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models +) ####### IMAGE GENERATION MODELS ################### -openai_image_generation_models = [ - "dall-e-2", - "dall-e-3" -] +openai_image_generation_models = ["dall-e-2", "dall-e-3"] from .timeout import timeout @@ -394,11 +471,11 @@ from .utils import ( get_llm_provider, completion_with_config, register_model, - encode, - decode, + encode, + decode, _calculate_retry_after, _should_retry, - get_secret + get_secret, ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig @@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig from .llms.sagemaker import SagemakerConfig from .llms.ollama import 
OllamaConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig +from .llms.bedrock import ( + AmazonTitanConfig, + AmazonAI21Config, + AmazonAnthropicConfig, + AmazonCohereConfig, + AmazonLlamaConfig, +) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig from .main import * # type: ignore @@ -429,13 +512,13 @@ from .exceptions import ( ServiceUnavailableError, OpenAIError, ContextWindowExceededError, - BudgetExceededError, + BudgetExceededError, APIError, Timeout, APIConnectionError, - APIResponseValidationError, - UnprocessableEntityError + APIResponseValidationError, + UnprocessableEntityError, ) from .budget_manager import BudgetManager from .proxy.proxy_cli import run_server -from .router import Router \ No newline at end of file +from .router import Router diff --git a/litellm/_logging.py b/litellm/_logging.py index 0c6881406..0bcc89dde 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -1,8 +1,9 @@ set_verbose = False + def print_verbose(print_statement): try: if set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: - pass \ No newline at end of file + pass diff --git a/litellm/_redis.py b/litellm/_redis.py index eca517079..bee73f134 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -13,6 +13,7 @@ import inspect import redis, litellm from typing import List, Optional + def _get_redis_kwargs(): arg_spec = inspect.getfullargspec(redis.Redis) @@ -23,32 +24,26 @@ def _get_redis_kwargs(): "retry", } - - include_args = [ - "url" - ] + include_args = ["url"] - available_args = [ - x for x in arg_spec.args if x not in exclude_args - ] + include_args + available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args return available_args + def _get_redis_env_kwarg_mapping(): PREFIX = "REDIS_" - return { - f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs() - } + return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()} def _redis_kwargs_from_environment(): mapping = _get_redis_env_kwarg_mapping() - return_dict = {} + return_dict = {} for k, v in mapping.items(): - value = litellm.get_secret(k, default_value=None) # check os.environ/key vault - if value is not None: + value = litellm.get_secret(k, default_value=None) # check os.environ/key vault + if value is not None: return_dict[v] = value return return_dict @@ -56,21 +51,26 @@ def _redis_kwargs_from_environment(): def get_redis_url_from_environment(): if "REDIS_URL" in os.environ: return os.environ["REDIS_URL"] - + if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ: - raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.") + raise ValueError( + "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis." 
+ ) if "REDIS_PASSWORD" in os.environ: redis_password = f":{os.environ['REDIS_PASSWORD']}@" else: redis_password = "" - return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}" + return ( + f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}" + ) + def get_redis_client(**env_overrides): ### check if "os.environ/" passed in - for k, v in env_overrides.items(): - if isinstance(v, str) and v.startswith("os.environ/"): + for k, v in env_overrides.items(): + if isinstance(v, str) and v.startswith("os.environ/"): v = v.replace("os.environ/", "") value = litellm.get_secret(v) env_overrides[k] = value @@ -80,14 +80,14 @@ def get_redis_client(**env_overrides): **env_overrides, } - if "url" in redis_kwargs and redis_kwargs['url'] is not None: + if "url" in redis_kwargs and redis_kwargs["url"] is not None: redis_kwargs.pop("host", None) redis_kwargs.pop("port", None) redis_kwargs.pop("db", None) redis_kwargs.pop("password", None) - + return redis.Redis.from_url(**redis_kwargs) - elif "host" not in redis_kwargs or redis_kwargs['host'] is None: + elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") - return redis.Redis(**redis_kwargs) \ No newline at end of file + return redis.Redis(**redis_kwargs) diff --git a/litellm/budget_manager.py b/litellm/budget_manager.py index 07468e2f5..4a3bb2cae 100644 --- a/litellm/budget_manager.py +++ b/litellm/budget_manager.py @@ -1,119 +1,166 @@ import os, json, time -import litellm +import litellm from litellm.utils import ModelResponse import requests, threading from typing import Optional, Union, Literal + class BudgetManager: - def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None): + def __init__( + self, + project_name: str, + client_type: str = "local", + api_base: Optional[str] = None, + ): self.client_type = client_type self.project_name = project_name self.api_base = api_base or "https://api.litellm.ai" ## load the data or init the initial dictionaries - self.load_data() - + self.load_data() + def print_verbose(self, print_statement): try: if litellm.set_verbose: import logging + logging.info(print_statement) except: pass - + def load_data(self): if self.client_type == "local": # Check if user dict file exists if os.path.isfile("user_cost.json"): # Load the user dict - with open("user_cost.json", 'r') as json_file: + with open("user_cost.json", "r") as json_file: self.user_dict = json.load(json_file) else: self.print_verbose("User Dictionary not found!") - self.user_dict = {} + self.user_dict = {} self.print_verbose(f"user dict from local: {self.user_dict}") elif self.client_type == "hosted": # Load the user_dict from hosted db url = self.api_base + "/get_budget" - headers = {'Content-Type': 'application/json'} - data = { - 'project_name' : self.project_name - } + headers = {"Content-Type": "application/json"} + data = {"project_name": self.project_name} response = requests.post(url, headers=headers, json=data) response = response.json() if response["status"] == "error": - self.user_dict = {} # assume this means the user dict hasn't been stored yet + self.user_dict = ( + {} + ) # assume this means the user dict hasn't been stored yet else: self.user_dict = response["data"] - def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float 
= time.time()): + def create_budget( + self, + total_budget: float, + user: str, + duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, + created_at: float = time.time(), + ): self.user_dict[user] = {"total_budget": total_budget} if duration is None: return self.user_dict[user] - - if duration == 'daily': + + if duration == "daily": duration_in_days = 1 - elif duration == 'weekly': + elif duration == "weekly": duration_in_days = 7 - elif duration == 'monthly': + elif duration == "monthly": duration_in_days = 28 - elif duration == 'yearly': + elif duration == "yearly": duration_in_days = 365 else: - raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""") - self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at} - self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution + raise ValueError( + """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""" + ) + self.user_dict[user] = { + "total_budget": total_budget, + "duration": duration_in_days, + "created_at": created_at, + "last_updated_at": created_at, + } + self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution return self.user_dict[user] - + def projected_cost(self, model: str, messages: list, user: str): text = "".join(message["content"] for message in messages) prompt_tokens = litellm.token_counter(model=model, text=text) - prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0) + prompt_cost, _ = litellm.cost_per_token( + model=model, prompt_tokens=prompt_tokens, completion_tokens=0 + ) current_cost = self.user_dict[user].get("current_cost", 0) projected_cost = prompt_cost + current_cost return projected_cost - + def get_total_budget(self, user: str): return self.user_dict[user]["total_budget"] - - def update_cost(self, user: str, completion_obj: Optional[ModelResponse] = None, model: Optional[str] = None, input_text: Optional[str] = None, output_text: Optional[str] = None): - if model and input_text and output_text: - prompt_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": input_text}]) - completion_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": output_text}]) - prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar - elif completion_obj: - cost = litellm.completion_cost(completion_response=completion_obj) - model = completion_obj['model'] # if this throws an error try, model = completion_obj['model'] - else: - raise ValueError("Either a chat completion object or the text response needs to be passed in. 
Learn more - https://docs.litellm.ai/docs/budget_manager") - self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0) + def update_cost( + self, + user: str, + completion_obj: Optional[ModelResponse] = None, + model: Optional[str] = None, + input_text: Optional[str] = None, + output_text: Optional[str] = None, + ): + if model and input_text and output_text: + prompt_tokens = litellm.token_counter( + model=model, messages=[{"role": "user", "content": input_text}] + ) + completion_tokens = litellm.token_counter( + model=model, messages=[{"role": "user", "content": output_text}] + ) + ( + prompt_tokens_cost_usd_dollar, + completion_tokens_cost_usd_dollar, + ) = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar + elif completion_obj: + cost = litellm.completion_cost(completion_response=completion_obj) + model = completion_obj[ + "model" + ] # if this throws an error try, model = completion_obj['model'] + else: + raise ValueError( + "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager" + ) + + self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get( + "current_cost", 0 + ) if "model_cost" in self.user_dict[user]: - self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user]["model_cost"].get(model, 0) + self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][ + "model_cost" + ].get(model, 0) else: self.user_dict[user]["model_cost"] = {model: cost} - self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution + self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution return {"user": self.user_dict[user]} - - + def get_current_cost(self, user): return self.user_dict[user].get("current_cost", 0) - + def get_model_cost(self, user): return self.user_dict[user].get("model_cost", 0) - + def is_valid_user(self, user: str) -> bool: return user in self.user_dict - + def get_users(self): return list(self.user_dict.keys()) - + def reset_cost(self, user): self.user_dict[user]["current_cost"] = 0 self.user_dict[user]["model_cost"] = {} return {"user": self.user_dict[user]} - + def reset_on_duration(self, user: str): # Get current and creation time last_updated_at = self.user_dict[user]["last_updated_at"] @@ -121,38 +168,39 @@ class BudgetManager: # Convert duration from days to seconds duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60 - + # Check if duration has elapsed if current_time - last_updated_at >= duration_in_seconds: # Reset cost if duration has elapsed and update the creation time self.reset_cost(user) self.user_dict[user]["last_updated_at"] = current_time self._save_data_thread() # Save the data - + def update_budget_all_users(self): for user in self.get_users(): if "duration" in self.user_dict[user]: self.reset_on_duration(user) def _save_data_thread(self): - thread = threading.Thread(target=self.save_data) # [Non-Blocking]: saves data without blocking execution + thread = threading.Thread( + target=self.save_data + ) # [Non-Blocking]: saves data without blocking execution thread.start() def save_data(self): if self.client_type == "local": - import json - - # save the user dict - with open("user_cost.json", 'w') as json_file: - json.dump(self.user_dict, json_file, indent=4) # Indent for pretty formatting + import 
json + + # save the user dict + with open("user_cost.json", "w") as json_file: + json.dump( + self.user_dict, json_file, indent=4 + ) # Indent for pretty formatting return {"status": "success"} elif self.client_type == "hosted": url = self.api_base + "/set_budget" - headers = {'Content-Type': 'application/json'} - data = { - 'project_name' : self.project_name, - "user_dict": self.user_dict - } + headers = {"Content-Type": "application/json"} + data = {"project_name": self.project_name, "user_dict": self.user_dict} response = requests.post(url, headers=headers, json=data) response = response.json() - return response \ No newline at end of file + return response diff --git a/litellm/caching.py b/litellm/caching.py index 73dde7cf9..78dbaf403 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -12,13 +12,15 @@ import time, logging import json, traceback, ast from typing import Optional, Literal, List + def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass + class BaseCache: def set_cache(self, key, value, **kwargs): raise NotImplementedError @@ -45,13 +47,13 @@ class InMemoryCache(BaseCache): self.cache_dict.pop(key, None) return None original_cached_response = self.cache_dict[key] - try: + try: cached_response = json.loads(original_cached_response) - except: + except: cached_response = original_cached_response return cached_response return None - + def flush_cache(self): self.cache_dict.clear() self.ttl_dict.clear() @@ -60,17 +62,18 @@ class InMemoryCache(BaseCache): class RedisCache(BaseCache): def __init__(self, host=None, port=None, password=None, **kwargs): import redis + # if users don't provider one, use the default litellm cache from ._redis import get_redis_client redis_kwargs = {} - if host is not None: + if host is not None: redis_kwargs["host"] = host if port is not None: redis_kwargs["port"] = port - if password is not None: + if password is not None: redis_kwargs["password"] = password - + redis_kwargs.update(kwargs) self.redis_client = get_redis_client(**redis_kwargs) @@ -88,13 +91,19 @@ class RedisCache(BaseCache): try: print_verbose(f"Get Redis Cache: key: {key}") cached_response = self.redis_client.get(key) - print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}") + print_verbose( + f"Got Redis Cache: key: {key}, cached_response {cached_response}" + ) if cached_response != None: # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_response.decode("utf-8") # Convert bytes to string - try: - cached_response = json.loads(cached_response) # Convert string to dictionary - except: + cached_response = cached_response.decode( + "utf-8" + ) # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except: cached_response = ast.literal_eval(cached_response) return cached_response except Exception as e: @@ -105,34 +114,40 @@ class RedisCache(BaseCache): def flush_cache(self): self.redis_client.flushall() -class DualCache(BaseCache): + +class DualCache(BaseCache): """ - This updates both Redis and an in-memory cache simultaneously. - When data is updated or inserted, it is written to both the in-memory cache + Redis. + This updates both Redis and an in-memory cache simultaneously. + When data is updated or inserted, it is written to both the in-memory cache + Redis. This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data. 
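    A minimal sketch of this write-through behavior, assuming only that litellm is importable
    (no Redis instance is required, since redis_cache is left unset):

        from litellm.caching import DualCache, InMemoryCache

        # With redis_cache left as None, only the in-memory layer is exercised.
        cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=None)
        cache.set_cache("greeting", "hello")             # written to every configured layer
        assert cache.get_cache("greeting") == "hello"    # reads check the in-memory layer first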
""" - def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None: + + def __init__( + self, + in_memory_cache: Optional[InMemoryCache] = None, + redis_cache: Optional[RedisCache] = None, + ) -> None: super().__init__() # If in_memory_cache is not provided, use the default InMemoryCache self.in_memory_cache = in_memory_cache or InMemoryCache() # If redis_cache is not provided, use the default RedisCache self.redis_cache = redis_cache - + def set_cache(self, key, value, **kwargs): # Update both Redis and in-memory cache - try: + try: print_verbose(f"set cache: key: {key}; value: {value}") if self.in_memory_cache is not None: self.in_memory_cache.set_cache(key, value, **kwargs) if self.redis_cache is not None: self.redis_cache.set_cache(key, value, **kwargs) - except Exception as e: + except Exception as e: print_verbose(e) def get_cache(self, key, **kwargs): # Try to fetch from in-memory cache first - try: + try: print_verbose(f"get cache: cache key: {key}") result = None if self.in_memory_cache is not None: @@ -141,7 +156,7 @@ class DualCache(BaseCache): if in_memory_result is not None: result = in_memory_result - if self.redis_cache is not None: + if self.redis_cache is not None: # If not found in in-memory cache, try fetching from Redis redis_result = self.redis_cache.get_cache(key, **kwargs) @@ -153,25 +168,28 @@ class DualCache(BaseCache): print_verbose(f"get cache: cache result: {result}") return result - except Exception as e: + except Exception as e: traceback.print_exc() - + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() if self.redis_cache is not None: self.redis_cache.flush_cache() + #### LiteLLM.Completion / Embedding Cache #### class Cache: def __init__( - self, - type: Optional[Literal["local", "redis"]] = "local", - host: Optional[str] = None, - port: Optional[str] = None, - password: Optional[str] = None, - supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"], - **kwargs + self, + type: Optional[Literal["local", "redis"]] = "local", + host: Optional[str] = None, + port: Optional[str] = None, + password: Optional[str] = None, + supported_call_types: Optional[ + List[Literal["completion", "acompletion", "embedding", "aembedding"]] + ] = ["completion", "acompletion", "embedding", "aembedding"], + **kwargs, ): """ Initializes the cache based on the given type. @@ -200,7 +218,7 @@ class Cache: litellm.success_callback.append("cache") if "cache" not in litellm._async_success_callback: litellm._async_success_callback.append("cache") - self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] + self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] def get_cache_key(self, *args, **kwargs): """ @@ -215,18 +233,37 @@ class Cache: """ cache_key = "" print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}") - + # for streaming, we use preset_cache_key. 
It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None: print_verbose(f"\nReturning preset cache key: {cache_key}") return kwargs.get("litellm_params", {}).get("preset_cache_key", None) # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4] - completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"] - embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs - + completion_kwargs = [ + "model", + "messages", + "temperature", + "top_p", + "n", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + ] + embedding_only_kwargs = [ + "input", + "encoding_format", + ] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs + # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set() - combined_kwargs = completion_kwargs + embedding_only_kwargs + combined_kwargs = completion_kwargs + embedding_only_kwargs for param in combined_kwargs: # ignore litellm params here if param in kwargs: @@ -241,8 +278,8 @@ class Cache: model_group = metadata.get("model_group", None) caching_groups = metadata.get("caching_groups", None) if caching_groups: - for group in caching_groups: - if model_group in group: + for group in caching_groups: + if model_group in group: caching_group = group break if litellm_params is not None: @@ -251,23 +288,34 @@ class Cache: model_group = metadata.get("model_group", None) caching_groups = metadata.get("caching_groups", None) if caching_groups: - for group in caching_groups: - if model_group in group: + for group in caching_groups: + if model_group in group: caching_group = group break - param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"] + param_value = ( + caching_group or model_group or kwargs[param] + ) # use caching_group, if set then model_group if it exists, else use kwargs["model"] else: if kwargs[param] is None: - continue # ignore None params + continue # ignore None params param_value = kwargs[param] - cache_key+= f"{str(param)}: {str(param_value)}" + cache_key += f"{str(param)}: {str(param_value)}" print_verbose(f"\nCreated cache key: {cache_key}") return cache_key def generate_streaming_content(self, content): chunk_size = 5 # Adjust the chunk size as needed for i in range(0, len(content), chunk_size): - yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]} + yield { + "choices": [ + { + "delta": { + "role": "assistant", + "content": content[i : i + chunk_size], + } + } + ] + } time.sleep(0.02) def get_cache(self, *args, **kwargs): @@ -319,4 +367,4 @@ class Cache: pass async def _async_add_cache(self, result, *args, **kwargs): - self.add_cache(result, *args, **kwargs) \ No newline at end of file + self.add_cache(result, *args, **kwargs) diff --git a/litellm/deprecated_litellm_server/__init__.py b/litellm/deprecated_litellm_server/__init__.py index 019bc5a11..54b9216d9 100644 --- 
a/litellm/deprecated_litellm_server/__init__.py +++ b/litellm/deprecated_litellm_server/__init__.py @@ -1,2 +1,2 @@ # from .main import * -# from .server_utils import * \ No newline at end of file +# from .server_utils import * diff --git a/litellm/deprecated_litellm_server/main.py b/litellm/deprecated_litellm_server/main.py index 11f011db3..966d2ed19 100644 --- a/litellm/deprecated_litellm_server/main.py +++ b/litellm/deprecated_litellm_server/main.py @@ -33,7 +33,7 @@ # llm_model_list: Optional[list] = None # server_settings: Optional[dict] = None -# set_callbacks() # sets litellm callbacks for logging if they exist in the environment +# set_callbacks() # sets litellm callbacks for logging if they exist in the environment # if "CONFIG_FILE_PATH" in os.environ: # llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH")) @@ -44,7 +44,7 @@ # @router.get("/models") # if project requires model list # def model_list(): # all_models = litellm.utils.get_valid_models() -# if llm_model_list: +# if llm_model_list: # all_models += llm_model_list # return dict( # data=[ @@ -79,8 +79,8 @@ # @router.post("/v1/embeddings") # @router.post("/embeddings") # async def embedding(request: Request): -# try: -# data = await request.json() +# try: +# data = await request.json() # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers # if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header # api_key = request.headers.get("authorization") @@ -106,13 +106,13 @@ # data = await request.json() # server_model = server_settings.get("completion_model", None) if server_settings else None # data["model"] = server_model or model or data["model"] -# ## CHECK KEYS ## +# ## CHECK KEYS ## # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers # # env_validation = litellm.validate_environment(model=data["model"]) # # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header # # if "authorization" in request.headers: # # api_key = request.headers.get("authorization") -# # elif "api-key" in request.headers: +# # elif "api-key" in request.headers: # # api_key = request.headers.get("api-key") # # print(f"api_key in headers: {api_key}") # # if " " in api_key: @@ -122,11 +122,11 @@ # # api_key = api_key # # data["api_key"] = api_key # # print(f"api_key in data: {api_key}") -# ## CHECK CONFIG ## +# ## CHECK CONFIG ## # if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]: -# for m in llm_model_list: -# if data["model"] == m["model_name"]: -# for key, value in m["litellm_params"].items(): +# for m in llm_model_list: +# if data["model"] == m["model_name"]: +# for key, value in m["litellm_params"].items(): # data[key] = value # break # response = litellm.completion( @@ -145,21 +145,21 @@ # @router.post("/router/completions") # async def router_completion(request: Request): # global llm_router -# try: +# try: # data = await request.json() -# if "model_list" in data: +# if "model_list" in data: # llm_router = litellm.Router(model_list=data.pop("model_list")) -# if llm_router is None: +# if llm_router is None: # raise Exception("Save model list via config.yaml. 
Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body") - + # # openai.ChatCompletion.create replacement -# response = await llm_router.acompletion(model="gpt-3.5-turbo", +# response = await llm_router.acompletion(model="gpt-3.5-turbo", # messages=[{"role": "user", "content": "Hey, how's it going?"}]) # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses # return StreamingResponse(data_generator(response), media_type='text/event-stream') # return response -# except Exception as e: +# except Exception as e: # error_traceback = traceback.format_exc() # error_msg = f"{str(e)}\n\n{error_traceback}" # return {"error": error_msg} @@ -167,11 +167,11 @@ # @router.post("/router/embedding") # async def router_embedding(request: Request): # global llm_router -# try: +# try: # data = await request.json() -# if "model_list" in data: +# if "model_list" in data: # llm_router = litellm.Router(model_list=data.pop("model_list")) -# if llm_router is None: +# if llm_router is None: # raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body") # response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore @@ -180,7 +180,7 @@ # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses # return StreamingResponse(data_generator(response), media_type='text/event-stream') # return response -# except Exception as e: +# except Exception as e: # error_traceback = traceback.format_exc() # error_msg = f"{str(e)}\n\n{error_traceback}" # return {"error": error_msg} @@ -190,4 +190,4 @@ # return "LiteLLM: RUNNING" -# app.include_router(router) \ No newline at end of file +# app.include_router(router) diff --git a/litellm/deprecated_litellm_server/server_utils.py b/litellm/deprecated_litellm_server/server_utils.py index 209acc8b9..75e5aa4d7 100644 --- a/litellm/deprecated_litellm_server/server_utils.py +++ b/litellm/deprecated_litellm_server/server_utils.py @@ -3,7 +3,7 @@ # import dotenv # dotenv.load_dotenv() # load env variables -# def print_verbose(print_statement): +# def print_verbose(print_statement): # pass # def get_package_version(package_name): @@ -27,32 +27,31 @@ # def set_callbacks(): # ## LOGGING -# if len(os.getenv("SET_VERBOSE", "")) > 0: -# if os.getenv("SET_VERBOSE") == "True": +# if len(os.getenv("SET_VERBOSE", "")) > 0: +# if os.getenv("SET_VERBOSE") == "True": # litellm.set_verbose = True # print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m") -# else: +# else: # litellm.set_verbose = False # ### LANGFUSE # if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0: -# litellm.success_callback = ["langfuse"] +# litellm.success_callback = ["langfuse"] # print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m") - -# ## CACHING + +# ## CACHING # ### REDIS -# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0: +# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0: # # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}") # # from litellm.caching import Cache # # litellm.cache = 
Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) # # print("\033[92mLiteLLM: Switched on Redis caching\033[0m") - # def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'): # config = {} -# server_settings = {} -# try: +# server_settings = {} +# try: # if os.path.exists(config_file_path): # type: ignore # with open(config_file_path, 'r') as file: # type: ignore # config = yaml.safe_load(file) @@ -63,24 +62,24 @@ # ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral') # server_settings = config.get("server_settings", None) -# if server_settings: +# if server_settings: # server_settings = server_settings # ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) # litellm_settings = config.get('litellm_settings', None) -# if litellm_settings: -# for key, value in litellm_settings.items(): +# if litellm_settings: +# for key, value in litellm_settings.items(): # setattr(litellm, key, value) # ## MODEL LIST # model_list = config.get('model_list', None) -# if model_list: +# if model_list: # router = litellm.Router(model_list=model_list) - + # ## ENVIRONMENT VARIABLES # environment_variables = config.get('environment_variables', None) -# if environment_variables: -# for key, value in environment_variables.items(): +# if environment_variables: +# for key, value in environment_variables.items(): # os.environ[key] = value # return router, model_list, server_settings diff --git a/litellm/exceptions.py b/litellm/exceptions.py index ae714cfed..3898a5683 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -16,11 +16,11 @@ from openai import ( RateLimitError, APIStatusError, OpenAIError, - APIError, - APITimeoutError, - APIConnectionError, + APIError, + APITimeoutError, + APIConnectionError, APIResponseValidationError, - UnprocessableEntityError + UnprocessableEntityError, ) import httpx @@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore self.llm_provider = llm_provider self.model = model super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs + # raise when invalid models passed, example gpt-8 class NotFoundError(NotFoundError): # type: ignore def __init__(self, message, model, llm_provider, response: httpx.Response): @@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore self.model = model self.llm_provider = llm_provider super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs @@ -58,23 +55,21 @@ class BadRequestError(BadRequestError): # type: ignore self.model = model self.llm_provider = llm_provider super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs -class UnprocessableEntityError(UnprocessableEntityError): # type: ignore + +class UnprocessableEntityError(UnprocessableEntityError): # type: ignore def __init__(self, message, model, llm_provider, response: httpx.Response): self.status_code = 422 self.message = message self.model = model self.llm_provider = llm_provider super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs 
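A rough usage sketch for the mapped exception types above, assuming litellm is installed and a
provider API key is configured (the model name is only a placeholder). Because each class
subclasses the corresponding openai exception, callers can catch either form:

    import litellm
    from litellm.exceptions import (
        AuthenticationError,
        ContextWindowExceededError,
        RateLimitError,
    )

    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}],
        )
    except ContextWindowExceededError as e:
        # subclass of BadRequestError - the prompt exceeded the model's context window
        print(f"context window exceeded for {e.model} ({e.llm_provider}): {e.message}")
    except (AuthenticationError, RateLimitError) as e:
        # both carry status_code, llm_provider and message, set in the constructors above
        print(f"{e.status_code} from {e.llm_provider}: {e.message}")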
+ class Timeout(APITimeoutError): # type: ignore def __init__(self, message, model, llm_provider): self.status_code = 408 @@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore request=request ) # Call the base class constructor with the parameters it needs + class RateLimitError(RateLimitError): # type: ignore def __init__(self, message, llm_provider, model, response: httpx.Response): self.status_code = 429 @@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore self.llm_provider = llm_provider self.modle = model super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs + # sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors class ContextWindowExceededError(BadRequestError): # type: ignore def __init__(self, message, model, llm_provider, response: httpx.Response): @@ -106,12 +101,13 @@ class ContextWindowExceededError(BadRequestError): # type: ignore self.model = model self.llm_provider = llm_provider super().__init__( - message=self.message, - model=self.model, # type: ignore - llm_provider=self.llm_provider, # type: ignore - response=response + message=self.message, + model=self.model, # type: ignore + llm_provider=self.llm_provider, # type: ignore + response=response, ) # Call the base class constructor with the parameters it needs + class ServiceUnavailableError(APIStatusError): # type: ignore def __init__(self, message, llm_provider, model, response: httpx.Response): self.status_code = 503 @@ -119,50 +115,42 @@ class ServiceUnavailableError(APIStatusError): # type: ignore self.llm_provider = llm_provider self.model = model super().__init__( - self.message, - response=response, - body=None + self.message, response=response, body=None ) # Call the base class constructor with the parameters it needs # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401 -class APIError(APIError): # type: ignore - def __init__(self, status_code, message, llm_provider, model, request: httpx.Request): - self.status_code = status_code +class APIError(APIError): # type: ignore + def __init__( + self, status_code, message, llm_provider, model, request: httpx.Request + ): + self.status_code = status_code self.message = message self.llm_provider = llm_provider self.model = model - super().__init__( - self.message, - request=request, # type: ignore - body=None - ) + super().__init__(self.message, request=request, body=None) # type: ignore + # raised if an invalid request (not get, delete, put, post) is made -class APIConnectionError(APIConnectionError): # type: ignore +class APIConnectionError(APIConnectionError): # type: ignore def __init__(self, message, llm_provider, model, request: httpx.Request): self.message = message self.llm_provider = llm_provider self.model = model self.status_code = 500 - super().__init__( - message=self.message, - request=request - ) + super().__init__(message=self.message, request=request) + # raised if an invalid request (not get, delete, put, post) is made -class APIResponseValidationError(APIResponseValidationError): # type: ignore +class APIResponseValidationError(APIResponseValidationError): # type: ignore def __init__(self, message, llm_provider, model): self.message = message self.llm_provider = llm_provider self.model = model request = 
httpx.Request(method="POST", url="https://api.openai.com/v1") response = httpx.Response(status_code=500, request=request) - super().__init__( - response=response, - body=None, - message=message - ) + super().__init__(response=response, body=None, message=message) + class OpenAIError(OpenAIError): # type: ignore def __init__(self, original_exception): @@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore ) self.llm_provider = "openai" + class BudgetExceededError(Exception): def __init__(self, current_cost, max_budget): self.current_cost = current_cost @@ -183,7 +172,8 @@ class BudgetExceededError(Exception): message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}" super().__init__(message) -## DEPRECATED ## + +## DEPRECATED ## class InvalidRequestError(BadRequestError): # type: ignore def __init__(self, message, model, llm_provider): self.status_code = 400 diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 64c5fc4cf..316e48aed 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -5,32 +5,33 @@ import requests from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache from typing import Literal + dotenv.load_dotenv() # Loading env variables using dotenv import traceback -class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class +class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class # Class variables or attributes def __init__(self): pass - def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): pass - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): pass - + def log_stream_event(self, kwargs, response_obj, start_time, end_time): pass - def log_success_event(self, kwargs, response_obj, start_time, end_time): + def log_success_event(self, kwargs, response_obj, start_time, end_time): pass - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + def log_failure_event(self, kwargs, response_obj, start_time, end_time): pass - #### ASYNC #### - + #### ASYNC #### + async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): pass @@ -43,81 +44,87 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): pass - #### CALL HOOKS - proxy only #### + #### CALL HOOKS - proxy only #### """ Control the modify incoming / outgoung data before calling the model """ - async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]): + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal["completion", "embeddings"], + ): pass - - async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): + + async def async_post_call_failure_hook( + self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth + ): pass #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): - try: + try: kwargs["model"] = model 
kwargs["messages"] = messages kwargs["log_event_type"] = "pre_api_call" callback_func( kwargs, ) - print_verbose( - f"Custom Logger - model call details: {kwargs}" - ) - except: + print_verbose(f"Custom Logger - model call details: {kwargs}") + except: traceback.print_exc() print_verbose(f"Custom Logger Error - {traceback.format_exc()}") - async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func): - try: + async def async_log_input_event( + self, model, messages, kwargs, print_verbose, callback_func + ): + try: kwargs["model"] = model kwargs["messages"] = messages kwargs["log_event_type"] = "pre_api_call" await callback_func( kwargs, ) - print_verbose( - f"Custom Logger - model call details: {kwargs}" - ) - except: + print_verbose(f"Custom Logger - model call details: {kwargs}") + except: traceback.print_exc() print_verbose(f"Custom Logger Error - {traceback.format_exc()}") - - def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): + def log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func + ): # Method definition try: kwargs["log_event_type"] = "post_api_call" callback_func( - kwargs, # kwargs to func + kwargs, # kwargs to func response_obj, start_time, end_time, ) - print_verbose( - f"Custom Logger - final response object: {response_obj}" - ) + print_verbose(f"Custom Logger - final response object: {response_obj}") except: # traceback.print_exc() print_verbose(f"Custom Logger Error - {traceback.format_exc()}") pass - - async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): + + async def async_log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func + ): # Method definition try: kwargs["log_event_type"] = "post_api_call" await callback_func( - kwargs, # kwargs to func + kwargs, # kwargs to func response_obj, start_time, end_time, ) - print_verbose( - f"Custom Logger - final response object: {response_obj}" - ) + print_verbose(f"Custom Logger - final response object: {response_obj}") except: # traceback.print_exc() print_verbose(f"Custom Logger Error - {traceback.format_exc()}") - pass \ No newline at end of file + pass diff --git a/litellm/integrations/dynamodb.py b/litellm/integrations/dynamodb.py index c025a0edc..2ed6c3f9f 100644 --- a/litellm/integrations/dynamodb.py +++ b/litellm/integrations/dynamodb.py @@ -10,19 +10,28 @@ import datetime, subprocess, sys import litellm, uuid from litellm._logging import print_verbose + class DyanmoDBLogger: # Class variables or attributes def __init__(self): # Instance variables import boto3 - self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"]) + + self.dynamodb = boto3.resource( + "dynamodb", region_name=os.environ["AWS_REGION_NAME"] + ) if litellm.dynamodb_table_name is None: - raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=`") + raise ValueError( + "LiteLLM Error, trying to use DynamoDB but not table name passed. 
Create a table and set `litellm.dynamodb_table_name=`" + ) self.table_name = litellm.dynamodb_table_name - async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + async def _async_log_event( + self, kwargs, response_obj, start_time, end_time, print_verbose + ): self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) + def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): try: print_verbose( @@ -32,7 +41,9 @@ class DyanmoDBLogger: # construct payload to send to DynamoDB # follows the same params as langfuse.py litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None messages = kwargs.get("messages") optional_params = kwargs.get("optional_params", {}) call_type = kwargs.get("call_type", "litellm.completion") @@ -51,7 +62,7 @@ class DyanmoDBLogger: "messages": messages, "response": response_obj, "usage": usage, - "metadata": metadata + "metadata": metadata, } # Ensure everything in the payload is converted to str @@ -62,9 +73,8 @@ class DyanmoDBLogger: # non blocking if it can't cast to a str pass - print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}") - + # put data in dyanmo DB table = self.dynamodb.Table(self.table_name) # Assuming log_data is a dictionary with log information @@ -79,4 +89,4 @@ class DyanmoDBLogger: except: traceback.print_exc() print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}") - pass \ No newline at end of file + pass diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 2c478871b..d418aad2a 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -64,7 +64,9 @@ class LangFuseLogger: # end of processing langfuse ######################## input = prompt output = response_obj["choices"][0]["message"].json() - print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}") + print( + f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}" + ) self._log_langfuse_v2( user_id, metadata, @@ -171,7 +173,6 @@ class LangFuseLogger: user_id=user_id, ) - trace.generation( name=metadata.get("generation_name", "litellm-completion"), startTime=start_time, diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index de9dd2f71..d951d6924 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -8,25 +8,27 @@ from datetime import datetime dotenv.load_dotenv() # Loading env variables using dotenv import traceback + class LangsmithLogger: # Class variables or attributes def __init__(self): self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY") - def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): # Method definition # inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb metadata = {} if "litellm_params" in kwargs: metadata = kwargs["litellm_params"].get("metadata", {}) - # set project name and run_name for langsmith logging + # set project name and run_name for langsmith logging # users can pass project_name and run name to litellm.completion() # Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"}) # if not set litellm will use default 
project_name = litellm-completion, run_name = LLMRun project_name = metadata.get("project_name", "litellm-completion") run_name = metadata.get("run_name", "LLMRun") - print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}") + print_verbose( + f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" + ) try: print_verbose( f"Langsmith Logging - Enters logging function for model {kwargs}" @@ -34,6 +36,7 @@ class LangsmithLogger: import requests import datetime from datetime import timezone + try: start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat() end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat() @@ -45,7 +48,7 @@ class LangsmithLogger: new_kwargs = {} for key in kwargs: value = kwargs[key] - if key == "start_time" or key =="end_time": + if key == "start_time" or key == "end_time": pass elif type(value) != dict: new_kwargs[key] = value @@ -54,18 +57,14 @@ class LangsmithLogger: "https://api.smith.langchain.com/runs", json={ "name": run_name, - "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" - "inputs": { - **new_kwargs - }, + "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" + "inputs": {**new_kwargs}, "outputs": response_obj.json(), "session_name": project_name, "start_time": start_time, "end_time": end_time, }, - headers={ - "x-api-key": self.langsmith_api_key - } + headers={"x-api-key": self.langsmith_api_key}, ) print_verbose( f"Langsmith Layer Logging - final response object: {response_obj}" diff --git a/litellm/integrations/litedebugger.py b/litellm/integrations/litedebugger.py index d69602d5c..0b4ef03ad 100644 --- a/litellm/integrations/litedebugger.py +++ b/litellm/integrations/litedebugger.py @@ -1,6 +1,7 @@ import requests, traceback, json, os import types + class LiteDebugger: user_email = None dashboard_url = None @@ -12,9 +13,15 @@ class LiteDebugger: def validate_environment(self, email): try: - self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")) - if self.user_email == None: # if users are trying to use_client=True but token not set - raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token") + self.user_email = ( + email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL") + ) + if ( + self.user_email == None + ): # if users are trying to use_client=True but token not set + raise ValueError( + "litellm.use_client = True but no token or email passed. 
Please set it in litellm.token" + ) self.dashboard_url = "https://admin.litellm.ai/" + self.user_email try: print( @@ -42,7 +49,9 @@ class LiteDebugger: litellm_params, optional_params, ): - print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}") + print_verbose( + f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}" + ) try: print_verbose( f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" @@ -56,7 +65,11 @@ class LiteDebugger: updated_litellm_params = remove_key_value(litellm_params, "logger_fn") if call_type == "embedding": - for message in messages: # assuming the input is a list as required by the embedding function + for ( + message + ) in ( + messages + ): # assuming the input is a list as required by the embedding function litellm_data_obj = { "model": model, "messages": [{"role": "user", "content": message}], @@ -79,7 +92,9 @@ class LiteDebugger: elif call_type == "completion": litellm_data_obj = { "model": model, - "messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}], + "messages": messages + if isinstance(messages, list) + else [{"role": "user", "content": messages}], "end_user": end_user, "status": "initiated", "litellm_call_id": litellm_call_id, @@ -95,20 +110,30 @@ class LiteDebugger: headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj), ) - print_verbose(f"LiteDebugger: completion api response - {response.text}") + print_verbose( + f"LiteDebugger: completion api response - {response.text}" + ) except: print_verbose( f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" ) pass - def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream): - print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}") + def post_call_log_event( + self, original_response, litellm_call_id, print_verbose, call_type, stream + ): + print_verbose( + f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}" + ) try: if call_type == "embedding": litellm_data_obj = { "status": "received", - "additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector + "additional_details": { + "original_response": str( + original_response["data"][0]["embedding"][:5] + ) + }, # don't store the entire vector "litellm_call_id": litellm_call_id, "user_email": self.user_email, } @@ -122,7 +147,11 @@ class LiteDebugger: elif call_type == "completion" and stream: litellm_data_obj = { "status": "received", - "additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response}, + "additional_details": { + "original_response": "Streamed response" + if isinstance(original_response, types.GeneratorType) + else original_response + }, "litellm_call_id": litellm_call_id, "user_email": self.user_email, } @@ -146,10 +175,12 @@ class LiteDebugger: end_time, litellm_call_id, print_verbose, - call_type, - stream = False + call_type, + stream=False, ): - print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}") + print_verbose( + f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}" + ) try: print_verbose( f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}" @@ -186,7 +217,7 @@ class LiteDebugger: 
data=json.dumps(litellm_data_obj), ) elif call_type == "completion" and stream == True: - if len(response_obj["content"]) > 0: # don't log the empty strings + if len(response_obj["content"]) > 0: # don't log the empty strings litellm_data_obj = { "response_time": response_time, "total_cost": total_cost, diff --git a/litellm/integrations/prompt_layer.py b/litellm/integrations/prompt_layer.py index 4167ea60f..4bf2089de 100644 --- a/litellm/integrations/prompt_layer.py +++ b/litellm/integrations/prompt_layer.py @@ -18,19 +18,17 @@ class PromptLayerLogger: # Method definition try: new_kwargs = {} - new_kwargs['model'] = kwargs['model'] - new_kwargs['messages'] = kwargs['messages'] + new_kwargs["model"] = kwargs["model"] + new_kwargs["messages"] = kwargs["messages"] # add kwargs["optional_params"] to new_kwargs for optional_param in kwargs["optional_params"]: new_kwargs[optional_param] = kwargs["optional_params"][optional_param] - print_verbose( f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}" ) - request_response = requests.post( "https://api.promptlayer.com/rest/track-request", json={ @@ -51,8 +49,8 @@ class PromptLayerLogger: f"Prompt Layer Logging: success - final response object: {request_response.text}" ) response_json = request_response.json() - if "success" not in request_response.json(): - raise Exception("Promptlayer did not successfully log the response!") + if "success" not in request_response.json(): + raise Exception("Promptlayer did not successfully log the response!") if "request_id" in response_json: print(kwargs["litellm_params"]["metadata"]) @@ -62,10 +60,12 @@ class PromptLayerLogger: json={ "request_id": response_json["request_id"], "api_key": self.key, - "metadata": kwargs["litellm_params"]["metadata"] + "metadata": kwargs["litellm_params"]["metadata"], }, ) - print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}") + print_verbose( + f"Prompt Layer Logging: success - metadata post response object: {response.text}" + ) except: print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}") diff --git a/litellm/integrations/supabase.py b/litellm/integrations/supabase.py index bca48a451..a99e4abc4 100644 --- a/litellm/integrations/supabase.py +++ b/litellm/integrations/supabase.py @@ -9,6 +9,7 @@ import traceback import datetime, subprocess, sys import litellm + class Supabase: # Class variables or attributes supabase_table_name = "request_logs" diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py index c188ac8bd..bbdb9a1b0 100644 --- a/litellm/integrations/traceloop.py +++ b/litellm/integrations/traceloop.py @@ -2,6 +2,7 @@ class TraceloopLogger: def __init__(self): from traceloop.sdk.tracing.tracing import TracerWrapper from traceloop.sdk import Traceloop + Traceloop.init(app_name="Litellm-Server", disable_batch=True) self.tracer_wrapper = TracerWrapper() @@ -29,15 +30,18 @@ class TraceloopLogger: ) if "stop" in optional_params: span.set_attribute( - SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop") + SpanAttributes.LLM_CHAT_STOP_SEQUENCES, + optional_params.get("stop"), ) if "frequency_penalty" in optional_params: span.set_attribute( - SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty") + SpanAttributes.LLM_FREQUENCY_PENALTY, + optional_params.get("frequency_penalty"), ) if "presence_penalty" in optional_params: span.set_attribute( - SpanAttributes.LLM_PRESENCE_PENALTY, 
optional_params.get("presence_penalty") + SpanAttributes.LLM_PRESENCE_PENALTY, + optional_params.get("presence_penalty"), ) if "top_p" in optional_params: span.set_attribute( @@ -45,7 +49,10 @@ class TraceloopLogger: ) if "tools" in optional_params or "functions" in optional_params: span.set_attribute( - SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions")) + SpanAttributes.LLM_REQUEST_FUNCTIONS, + optional_params.get( + "tools", optional_params.get("functions") + ), ) if "user" in optional_params: span.set_attribute( @@ -53,7 +60,8 @@ class TraceloopLogger: ) if "max_tokens" in optional_params: span.set_attribute( - SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens") + SpanAttributes.LLM_REQUEST_MAX_TOKENS, + kwargs.get("max_tokens"), ) if "temperature" in optional_params: span.set_attribute( diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index c571eca3a..53e6070a5 100644 --- a/litellm/integrations/weights_biases.py +++ b/litellm/integrations/weights_biases.py @@ -1,4 +1,4 @@ -imported_openAIResponse=True +imported_openAIResponse = True try: import io import logging @@ -12,15 +12,12 @@ try: else: from typing_extensions import Literal, Protocol - logger = logging.getLogger(__name__) - K = TypeVar("K", bound=str) V = TypeVar("V") - - class OpenAIResponse(Protocol[K, V]): # type: ignore + class OpenAIResponse(Protocol[K, V]): # type: ignore # contains a (known) object attribute object: Literal["chat.completion", "edit", "text_completion"] @@ -30,7 +27,6 @@ try: def get(self, key: K, default: Optional[V] = None) -> Optional[V]: ... # pragma: no cover - class OpenAIRequestResponseResolver: def __call__( self, @@ -44,7 +40,9 @@ try: elif response["object"] == "text_completion": return self._resolve_completion(request, response, time_elapsed) elif response["object"] == "chat.completion": - return self._resolve_chat_completion(request, response, time_elapsed) + return self._resolve_chat_completion( + request, response, time_elapsed + ) else: logger.info(f"Unknown OpenAI response object: {response['object']}") except Exception as e: @@ -113,7 +111,8 @@ try: """Resolves the request and response objects for `openai.Completion`.""" request_str = f"\n\n**Prompt**: {request['prompt']}\n" choices = [ - f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"] + f"\n\n**Completion**: {choice['text']}\n" + for choice in response["choices"] ] return self._request_response_result_to_trace( @@ -167,9 +166,9 @@ try: ] trace = self.results_to_trace_tree(request, response, results, time_elapsed) return trace -except: - imported_openAIResponse=False +except: + imported_openAIResponse = False #### What this does #### @@ -182,29 +181,34 @@ from datetime import datetime dotenv.load_dotenv() # Loading env variables using dotenv import traceback + class WeightsBiasesLogger: # Class variables or attributes def __init__(self): try: import wandb except: - raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") - if imported_openAIResponse==False: - raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") + raise Exception( + "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" + ) + if imported_openAIResponse == False: + raise Exception( + "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" + ) self.resolver = 
OpenAIRequestResponseResolver() def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): # Method definition import wandb - + try: - print_verbose( - f"W&B Logging - Enters logging function for model {kwargs}" - ) + print_verbose(f"W&B Logging - Enters logging function for model {kwargs}") run = wandb.init() print_verbose(response_obj) - trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds()) + trace = self.resolver( + kwargs, response_obj, (end_time - start_time).total_seconds() + ) if trace is not None: run.log({"trace": trace}) diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py index e20997751..73d5afebe 100644 --- a/litellm/llms/ai21.py +++ b/litellm/llms/ai21.py @@ -7,76 +7,93 @@ from typing import Callable, Optional from litellm.utils import ModelResponse, Choices, Message import litellm + class AI21Error(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/") + self.request = httpx.Request( + method="POST", url="https://api.ai21.com/studio/v1/" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class AI21Config(): + +class AI21Config: """ Reference: https://docs.ai21.com/reference/j2-complete-ref The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters: - `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful. - + - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. - + - `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated. - + - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. - + - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. - + - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. - + - `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position. - + - `frequencyPenalty` (object): Placeholder for frequency penalty object. - + - `presencePenalty` (object): Placeholder for presence penalty object. - + - `countPenalty` (object): Placeholder for count penalty object. 
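    A hedged usage sketch: instantiating the config sets class-level defaults, which completion()
    then merges into any request that does not override them. The j2-mid model name and a
    configured AI21 API key are assumptions for illustration:

        import litellm

        # Class-level defaults; completion() reads these via AI21Config.get_config()
        # and only applies keys the caller did not pass explicitly.
        litellm.AI21Config(maxTokens=256, temperature=0.3)

        response = litellm.completion(
            model="j2-mid",  # assumed AI21 model name
            messages=[{"role": "user", "content": "Summarize what a token is."}],
        )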
""" - numResults: Optional[int]=None - maxTokens: Optional[int]=None - minTokens: Optional[int]=None - temperature: Optional[float]=None - topP: Optional[float]=None - stopSequences: Optional[list]=None - topKReturn: Optional[int]=None - frequencePenalty: Optional[dict]=None - presencePenalty: Optional[dict]=None - countPenalty: Optional[dict]=None - def __init__(self, - numResults: Optional[int]=None, - maxTokens: Optional[int]=None, - minTokens: Optional[int]=None, - temperature: Optional[float]=None, - topP: Optional[float]=None, - stopSequences: Optional[list]=None, - topKReturn: Optional[int]=None, - frequencePenalty: Optional[dict]=None, - presencePenalty: Optional[dict]=None, - countPenalty: Optional[dict]=None) -> None: + numResults: Optional[int] = None + maxTokens: Optional[int] = None + minTokens: Optional[int] = None + temperature: Optional[float] = None + topP: Optional[float] = None + stopSequences: Optional[list] = None + topKReturn: Optional[int] = None + frequencePenalty: Optional[dict] = None + presencePenalty: Optional[dict] = None + countPenalty: Optional[dict] = None + + def __init__( + self, + numResults: Optional[int] = None, + maxTokens: Optional[int] = None, + minTokens: Optional[int] = None, + temperature: Optional[float] = None, + topP: Optional[float] = None, + stopSequences: Optional[list] = None, + topKReturn: Optional[int] = None, + frequencePenalty: Optional[dict] = None, + presencePenalty: Optional[dict] = None, + countPenalty: Optional[dict] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def validate_environment(api_key): @@ -91,6 +108,7 @@ def validate_environment(api_key): } return headers + def completion( model: str, messages: list, @@ -110,20 +128,18 @@ def completion( for message in messages: if "role" in message: if message["role"] == "user": - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: prompt += f"{message['content']}" - + ## Load Config - config = litellm.AI21Config.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AI21Config.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v data = { @@ -134,29 +150,26 @@ def completion( ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( api_base + model + "/complete", headers=headers, data=json.dumps(data) ) if response.status_code != 200: - raise AI21Error( - status_code=response.status_code, - 
message=response.text - ) + raise AI21Error(status_code=response.status_code, message=response.text) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() else: ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) ## RESPONSE OBJECT completion_response = response.json() try: @@ -164,18 +177,22 @@ def completion( for idx, item in enumerate(completion_response["completions"]): if len(item["data"]["text"]) > 0: message_obj = Message(content=item["data"]["text"]) - else: + else: message_obj = Message(content=None) - choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj) + choice_obj = Choices( + finish_reason=item["finishReason"]["reason"], + index=idx + 1, + message=message_obj, + ) choices_list.append(choice_obj) model_response["choices"] = choices_list except Exception as e: - raise AI21Error(message=traceback.format_exc(), status_code=response.status_code) + raise AI21Error( + message=traceback.format_exc(), status_code=response.status_code + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content")) ) @@ -189,6 +206,7 @@ def completion( } return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/aleph_alpha.py b/litellm/llms/aleph_alpha.py index 9bceec51b..7168e7369 100644 --- a/litellm/llms/aleph_alpha.py +++ b/litellm/llms/aleph_alpha.py @@ -8,17 +8,21 @@ import litellm from litellm.utils import ModelResponse, Choices, Message, Usage import httpx + class AlephAlphaError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete") + self.request = httpx.Request( + method="POST", url="https://api.aleph-alpha.com/complete" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class AlephAlphaConfig(): + +class AlephAlphaConfig: """ Reference: https://docs.aleph-alpha.com/api/complete/ @@ -42,13 +46,13 @@ class AlephAlphaConfig(): - `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied. - - `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties. + - `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties. - `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty. - `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions. 
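The reformatted AI21 handler above is a good example of the provider-config pattern repeated throughout these modules: defaults are stored at class level via `AI21Config(...)`, read back with `get_config()`, and merged into `optional_params` only for keys the caller did not pass explicitly. A minimal sketch of that precedence rule, using the `AI21Config` class from this file (the parameter values and the standalone merge loop are illustrative, not part of the patch):

    import litellm

    # Store class-level defaults for AI21 calls (AI21Config.__init__ writes them
    # onto the class with setattr, so they apply to every later request).
    litellm.AI21Config(maxTokens=256, temperature=0.5)

    # What completion() does before building the request body:
    optional_params = {"temperature": 0.9}   # value passed per-call wins
    config = litellm.AI21Config.get_config()
    for k, v in config.items():
        if k not in optional_params:         # completion(arg) > AI21Config(arg)
            optional_params[k] = v

    print(optional_params)   # {'temperature': 0.9, 'maxTokens': 256}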
- - `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side. + - `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side. - `n` (integer, nullable; default value: 1): The number of completions to return. @@ -68,87 +72,101 @@ class AlephAlphaConfig(): - `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion. - - `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled. + - `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled. - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. """ - maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens - minimum_tokens: Optional[int]=None - echo: Optional[bool]=None - temperature: Optional[int]=None - top_k: Optional[int]=None - top_p: Optional[int]=None - presence_penalty: Optional[int]=None - frequency_penalty: Optional[int]=None - sequence_penalty: Optional[int]=None - sequence_penalty_min_length: Optional[int]=None - repetition_penalties_include_prompt: Optional[bool]=None - repetition_penalties_include_completion: Optional[bool]=None - use_multiplicative_presence_penalty: Optional[bool]=None - use_multiplicative_frequency_penalty: Optional[bool]=None - use_multiplicative_sequence_penalty: Optional[bool]=None - penalty_bias: Optional[str]=None - penalty_exceptions_include_stop_sequences: Optional[bool]=None - best_of: Optional[int]=None - n: Optional[int]=None - logit_bias: Optional[dict]=None - log_probs: Optional[int]=None - stop_sequences: Optional[list]=None - tokens: Optional[bool]=None - raw_completion: Optional[bool]=None - disable_optimizations: Optional[bool]=None - completion_bias_inclusion: Optional[list]=None - completion_bias_exclusion: Optional[list]=None - completion_bias_inclusion_first_token_only: Optional[bool]=None - completion_bias_exclusion_first_token_only: Optional[bool]=None - contextual_control_threshold: Optional[int]=None - control_log_additive: Optional[bool]=None + maximum_tokens: Optional[ + int + ] = litellm.max_tokens # aleph alpha requires max tokens + minimum_tokens: Optional[int] = None + echo: Optional[bool] = None + temperature: Optional[int] = None + top_k: Optional[int] = None + top_p: Optional[int] = None + presence_penalty: Optional[int] = None + frequency_penalty: Optional[int] = None + sequence_penalty: Optional[int] = None + sequence_penalty_min_length: Optional[int] = None + repetition_penalties_include_prompt: Optional[bool] = None + repetition_penalties_include_completion: Optional[bool] = None + use_multiplicative_presence_penalty: Optional[bool] = None + use_multiplicative_frequency_penalty: Optional[bool] = None + use_multiplicative_sequence_penalty: Optional[bool] = None + penalty_bias: Optional[str] = None + penalty_exceptions_include_stop_sequences: Optional[bool] = None + best_of: Optional[int] = None + n: Optional[int] = None + logit_bias: Optional[dict] = None + log_probs: Optional[int] = None + stop_sequences: Optional[list] = None + tokens: Optional[bool] = None + raw_completion: Optional[bool] = None + disable_optimizations: Optional[bool] = None + completion_bias_inclusion: Optional[list] = None + completion_bias_exclusion: Optional[list] = None + 
completion_bias_inclusion_first_token_only: Optional[bool] = None + completion_bias_exclusion_first_token_only: Optional[bool] = None + contextual_control_threshold: Optional[int] = None + control_log_additive: Optional[bool] = None - def __init__(self, - maximum_tokens: Optional[int]=None, - minimum_tokens: Optional[int]=None, - echo: Optional[bool]=None, - temperature: Optional[int]=None, - top_k: Optional[int]=None, - top_p: Optional[int]=None, - presence_penalty: Optional[int]=None, - frequency_penalty: Optional[int]=None, - sequence_penalty: Optional[int]=None, - sequence_penalty_min_length: Optional[int]=None, - repetition_penalties_include_prompt: Optional[bool]=None, - repetition_penalties_include_completion: Optional[bool]=None, - use_multiplicative_presence_penalty: Optional[bool]=None, - use_multiplicative_frequency_penalty: Optional[bool]=None, - use_multiplicative_sequence_penalty: Optional[bool]=None, - penalty_bias: Optional[str]=None, - penalty_exceptions_include_stop_sequences: Optional[bool]=None, - best_of: Optional[int]=None, - n: Optional[int]=None, - logit_bias: Optional[dict]=None, - log_probs: Optional[int]=None, - stop_sequences: Optional[list]=None, - tokens: Optional[bool]=None, - raw_completion: Optional[bool]=None, - disable_optimizations: Optional[bool]=None, - completion_bias_inclusion: Optional[list]=None, - completion_bias_exclusion: Optional[list]=None, - completion_bias_inclusion_first_token_only: Optional[bool]=None, - completion_bias_exclusion_first_token_only: Optional[bool]=None, - contextual_control_threshold: Optional[int]=None, - control_log_additive: Optional[bool]=None) -> None: - + def __init__( + self, + maximum_tokens: Optional[int] = None, + minimum_tokens: Optional[int] = None, + echo: Optional[bool] = None, + temperature: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + presence_penalty: Optional[int] = None, + frequency_penalty: Optional[int] = None, + sequence_penalty: Optional[int] = None, + sequence_penalty_min_length: Optional[int] = None, + repetition_penalties_include_prompt: Optional[bool] = None, + repetition_penalties_include_completion: Optional[bool] = None, + use_multiplicative_presence_penalty: Optional[bool] = None, + use_multiplicative_frequency_penalty: Optional[bool] = None, + use_multiplicative_sequence_penalty: Optional[bool] = None, + penalty_bias: Optional[str] = None, + penalty_exceptions_include_stop_sequences: Optional[bool] = None, + best_of: Optional[int] = None, + n: Optional[int] = None, + logit_bias: Optional[dict] = None, + log_probs: Optional[int] = None, + stop_sequences: Optional[list] = None, + tokens: Optional[bool] = None, + raw_completion: Optional[bool] = None, + disable_optimizations: Optional[bool] = None, + completion_bias_inclusion: Optional[list] = None, + completion_bias_exclusion: Optional[list] = None, + completion_bias_inclusion_first_token_only: Optional[bool] = None, + completion_bias_exclusion_first_token_only: Optional[bool] = None, + contextual_control_threshold: Optional[int] = None, + control_log_additive: Optional[bool] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + 
return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def validate_environment(api_key): @@ -160,6 +178,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Bearer {api_key}" return headers + def completion( model: str, messages: list, @@ -177,9 +196,11 @@ def completion( headers = validate_environment(api_key) ## Load Config - config = litellm.AlephAlphaConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AlephAlphaConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v completion_url = api_base @@ -188,21 +209,17 @@ def completion( if "control" in model: # follow the ###Instruction / ###Response format for idx, message in enumerate(messages): if "role" in message: - if idx == 0: # set first message as instruction (required), let later user messages be input + if ( + idx == 0 + ): # set first message as instruction (required), let later user messages be input prompt += f"###Instruction: {message['content']}" else: if message["role"] == "system": - prompt += ( - f"###Instruction: {message['content']}" - ) + prompt += f"###Instruction: {message['content']}" elif message["role"] == "user": - prompt += ( - f"###Input: {message['content']}" - ) + prompt += f"###Input: {message['content']}" else: - prompt += ( - f"###Response: {message['content']}" - ) + prompt += f"###Response: {message['content']}" else: prompt += f"{message['content']}" else: @@ -215,24 +232,27 @@ def completion( ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( - completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, ) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() else: ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() @@ -247,18 +267,23 @@ def completion( for idx, item in enumerate(completion_response["completions"]): if len(item["completion"]) > 0: message_obj = Message(content=item["completion"]) - else: + else: message_obj = Message(content=None) - choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) + choice_obj = Choices( + finish_reason=item["finish_reason"], + index=idx + 1, + message=message_obj, + ) choices_list.append(choice_obj) model_response["choices"] = choices_list except: - raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code) + raise 
AlephAlphaError( + message=json.dumps(completion_response), + status_code=response.status_code, + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"]["content"]) ) @@ -268,11 +293,12 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 6b1d50ff8..4df032ba0 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -5,56 +5,76 @@ import requests import time from typing import Callable, Optional from litellm.utils import ModelResponse, Usage -import litellm +import litellm from .prompt_templates.factory import prompt_factory, custom_prompt import httpx + class AnthropicConstants(Enum): HUMAN_PROMPT = "\n\nHuman: " AI_PROMPT = "\n\nAssistant: " + class AnthropicError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete") + self.request = httpx.Request( + method="POST", url="https://api.anthropic.com/v1/complete" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class AnthropicConfig(): + +class AnthropicConfig: """ Reference: https://docs.anthropic.com/claude/reference/complete_post to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default - stop_sequences: Optional[list]=None - temperature: Optional[int]=None - top_p: Optional[int]=None - top_k: Optional[int]=None - metadata: Optional[dict]=None - def __init__(self, - max_tokens_to_sample: Optional[int]=256, # anthropic requires a default - stop_sequences: Optional[list]=None, - temperature: Optional[int]=None, - top_p: Optional[int]=None, - top_k: Optional[int]=None, - metadata: Optional[dict]=None) -> None: - + max_tokens_to_sample: Optional[ + int + ] = litellm.max_tokens # anthropic requires a default + stop_sequences: Optional[list] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + metadata: Optional[dict] = None + + def __init__( + self, + max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default + stop_sequences: Optional[list] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + metadata: Optional[dict] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + 
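The most provider-specific piece of the Aleph Alpha handler is the prompt layout applied when the model name contains "control": the first message (and any later system message) becomes an `###Instruction:` block, user messages become `###Input:`, and anything else is rendered as `###Response:`. A standalone sketch mirroring that branch of `completion()` above (the helper name and sample messages are illustrative):

    def build_control_prompt(messages: list) -> str:
        # Mirrors the "control" branch of the Aleph Alpha handler: first message
        # is the instruction, later system messages stay instructions, user
        # messages become inputs, everything else is treated as a prior response.
        prompt = ""
        for idx, message in enumerate(messages):
            if "role" not in message:
                prompt += f"{message['content']}"
            elif idx == 0 or message["role"] == "system":
                prompt += f"###Instruction: {message['content']}"
            elif message["role"] == "user":
                prompt += f"###Input: {message['content']}"
            else:
                prompt += f"###Response: {message['content']}"
        return prompt

    print(build_control_prompt([
        {"role": "system", "content": "Answer in one sentence."},
        {"role": "user", "content": "What is nucleus sampling?"},
    ]))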
for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } # makes headers for API call @@ -71,6 +91,7 @@ def validate_environment(api_key): } return headers + def completion( model: str, messages: list, @@ -87,21 +108,25 @@ def completion( ): headers = validate_environment(api_key) if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages - ) + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) else: - prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic") - + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) + ## Load Config - config = litellm.AnthropicConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AnthropicConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v data = { @@ -116,7 +141,7 @@ def completion( api_key=api_key, additional_args={"complete_input_dict": data, "api_base": api_base}, ) - + ## COMPLETION CALL if "stream" in optional_params and optional_params["stream"] == True: response = requests.post( @@ -125,18 +150,20 @@ def completion( data=json.dumps(data), stream=optional_params["stream"], ) - + if response.status_code != 200: - raise AnthropicError(status_code=response.status_code, message=response.text) + raise AnthropicError( + status_code=response.status_code, message=response.text + ) return response.iter_lines() else: - response = requests.post( - api_base, headers=headers, data=json.dumps(data) - ) + response = requests.post(api_base, headers=headers, data=json.dumps(data)) if response.status_code != 200: - raise AnthropicError(status_code=response.status_code, message=response.text) - + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + ## LOGGING logging_obj.post_call( input=prompt, @@ -159,9 +186,9 @@ def completion( ) else: if len(completion_response["completion"]) > 0: - model_response["choices"][0]["message"]["content"] = completion_response[ - "completion" - ] + model_response["choices"][0]["message"][ + "content" + ] = completion_response["completion"] model_response.choices[0].finish_reason = completion_response["stop_reason"] ## CALCULATING USAGE @@ -177,11 +204,12 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git 
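Before the request is sent, the Anthropic handler converts the chat messages into Claude's text-completion format, either through a registered `custom_prompt_dict` entry or through `prompt_factory(..., custom_llm_provider="anthropic")`. The essence of that format is the alternating Human / Assistant layout defined by `AnthropicConstants` above. A rough, simplified sketch of the layout (the role mapping below is an assumption for illustration; the real conversion lives in `prompt_templates.factory` and is not shown in this patch):

    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "

    def claude_prompt(messages: list) -> str:
        # Simplified: treat system/user turns as Human and assistant turns as
        # Assistant, then end with an open Assistant turn for the model to fill.
        prompt = ""
        for m in messages:
            tag = HUMAN_PROMPT if m["role"] in ("system", "user") else AI_PROMPT
            prompt += f"{tag}{m['content']}"
        return prompt + AI_PROMPT

    print(claude_prompt([{"role": "user", "content": "Say hello."}]))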
a/litellm/llms/azure.py b/litellm/llms/azure.py index 2e75f7b40..123caafd0 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -1,7 +1,13 @@ from typing import Optional, Union, Any import types, requests from .base import BaseLLM -from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object +from litellm.utils import ( + ModelResponse, + Choices, + Message, + CustomStreamWrapper, + convert_to_model_response_object, +) from typing import Callable, Optional from litellm import OpenAIConfig import litellm, json @@ -9,8 +15,15 @@ import httpx from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from openai import AzureOpenAI, AsyncAzureOpenAI + class AzureOpenAIError(Exception): - def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): + def __init__( + self, + status_code, + message, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + ): self.status_code = status_code self.message = message if request: @@ -20,11 +33,14 @@ class AzureOpenAIError(Exception): if response: self.response = response else: - self.response = httpx.Response(status_code=status_code, request=self.request) + self.response = httpx.Response( + status_code=status_code, request=self.request + ) super().__init__( self.message ) # Call the base class constructor with the parameters it needs + class AzureOpenAIConfig(OpenAIConfig): """ Reference: https://platform.openai.com/docs/api-reference/chat/create @@ -49,33 +65,37 @@ class AzureOpenAIConfig(OpenAIConfig): - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. - - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. 
""" - def __init__(self, - frequency_penalty: Optional[int] = None, - function_call: Optional[Union[str, dict]]= None, - functions: Optional[list]= None, - logit_bias: Optional[dict]= None, - max_tokens: Optional[int]= None, - n: Optional[int]= None, - presence_penalty: Optional[int]= None, - stop: Optional[Union[str,list]]=None, - temperature: Optional[int]= None, - top_p: Optional[int]= None) -> None: - super().__init__(frequency_penalty, - function_call, - functions, - logit_bias, - max_tokens, - n, - presence_penalty, - stop, - temperature, - top_p) + def __init__( + self, + frequency_penalty: Optional[int] = None, + function_call: Optional[Union[str, dict]] = None, + functions: Optional[list] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + ) -> None: + super().__init__( + frequency_penalty, + function_call, + functions, + logit_bias, + max_tokens, + n, + presence_penalty, + stop, + temperature, + top_p, + ) + class AzureChatCompletion(BaseLLM): - def __init__(self) -> None: super().__init__() @@ -89,49 +109,51 @@ class AzureChatCompletion(BaseLLM): headers["Authorization"] = f"Bearer {azure_ad_token}" return headers - def completion(self, - model: str, - messages: list, - model_response: ModelResponse, - api_key: str, - api_base: str, - api_version: str, - api_type: str, - azure_ad_token: str, - print_verbose: Callable, - timeout, - logging_obj, - optional_params, - litellm_params, - logger_fn, - acompletion: bool = False, - headers: Optional[dict]=None, - client = None, - ): + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + api_key: str, + api_base: str, + api_version: str, + api_type: str, + azure_ad_token: str, + print_verbose: Callable, + timeout, + logging_obj, + optional_params, + litellm_params, + logger_fn, + acompletion: bool = False, + headers: Optional[dict] = None, + client=None, + ): super().completion() exception_mapping_worked = False try: - if model is None or messages is None: - raise AzureOpenAIError(status_code=422, message=f"Missing model or messages") - + raise AzureOpenAIError( + status_code=422, message=f"Missing model or messages" + ) + max_retries = optional_params.pop("max_retries", 2) ### CHECK IF CLOUDFLARE AI GATEWAY ### - ### if so - set the model as part of the base url - if "gateway.ai.cloudflare.com" in api_base: + ### if so - set the model as part of the base url + if "gateway.ai.cloudflare.com" in api_base: ## build base url - assume api base includes resource name if client is None: - if not api_base.endswith("/"): + if not api_base.endswith("/"): api_base += "/" api_base += f"{model}" - + azure_client_params = { "api_version": api_version, "base_url": f"{api_base}", "http_client": litellm.client_session, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -142,26 +164,53 @@ class AzureChatCompletion(BaseLLM): client = AsyncAzureOpenAI(**azure_client_params) else: client = AzureOpenAI(**azure_client_params) - + + data = {"model": None, "messages": messages, **optional_params} + else: data = { - "model": None, - "messages": messages, - **optional_params + "model": model, # type: ignore + "messages": messages, + **optional_params, } - else: - data = { - "model": model, # type: ignore - "messages": messages, - 
**optional_params - } - - if acompletion is True: + + if acompletion is True: if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) + return self.async_streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + ) else: - return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj) + return self.acompletion( + api_base=api_base, + data=data, + model_response=model_response, + api_key=api_key, + api_version=api_version, + model=model, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + logging_obj=logging_obj, + ) elif "stream" in optional_params and optional_params["stream"] == True: - return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) + return self.streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + ) else: ## LOGGING logging_obj.pre_call( @@ -169,16 +218,18 @@ class AzureChatCompletion(BaseLLM): api_key=api_key, additional_args={ "headers": { - "api_key": api_key, - "azure_ad_token": azure_ad_token + "api_key": api_key, + "azure_ad_token": azure_ad_token, }, "api_version": api_version, "api_base": api_base, "complete_input_dict": data, }, ) - if not isinstance(max_retries, int): - raise AzureOpenAIError(status_code=422, message="max retries must be an int") + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) # init AzureOpenAI Client azure_client_params = { "api_version": api_version, @@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM): "azure_deployment": model, "http_client": litellm.client_session, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -196,7 +247,7 @@ class AzureChatCompletion(BaseLLM): azure_client = AzureOpenAI(**azure_client_params) else: azure_client = client - response = azure_client.chat.completions.create(**data) # type: ignore + response = azure_client.chat.completions.create(**data) # type: ignore stringified_response = response.model_dump_json() ## LOGGING logging_obj.post_call( @@ -209,30 +260,36 @@ class AzureChatCompletion(BaseLLM): "api_base": api_base, }, ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) - except AzureOpenAIError as e: + return convert_to_model_response_object( + response_object=json.loads(stringified_response), + model_response_object=model_response, + ) + except AzureOpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: + except Exception as e: raise e - - async def acompletion(self, - api_key: str, - api_version: str, - model: str, - api_base: str, - data: dict, - timeout: Any, - model_response: ModelResponse, - azure_ad_token: Optional[str]=None, - client = None, 
# this is the AsyncAzureOpenAI - logging_obj=None, - ): - response = None - try: + + async def acompletion( + self, + api_key: str, + api_version: str, + model: str, + api_base: str, + data: dict, + timeout: Any, + model_response: ModelResponse, + azure_ad_token: Optional[str] = None, + client=None, # this is the AsyncAzureOpenAI + logging_obj=None, + ): + response = None + try: max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): - raise AzureOpenAIError(status_code=422, message="max retries must be an int") + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) # init AzureOpenAI Client azure_client_params = { "api_version": api_version, @@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM): "azure_deployment": model, "http_client": litellm.client_session, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -252,35 +309,46 @@ class AzureChatCompletion(BaseLLM): azure_client = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=azure_client.api_key, - additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, ) - response = await azure_client.chat.completions.create(**data) - return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) - except AzureOpenAIError as e: + response = await azure_client.chat.completions.create(**data) + return convert_to_model_response_object( + response_object=json.loads(response.model_dump_json()), + model_response_object=model_response, + ) + except AzureOpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: + except Exception as e: if hasattr(e, "status_code"): raise e else: raise AzureOpenAIError(status_code=500, message=str(e)) - def streaming(self, - logging_obj, - api_base: str, - api_key: str, - api_version: str, - data: dict, - model: str, - timeout: Any, - azure_ad_token: Optional[str]=None, - client=None, - ): + def streaming( + self, + logging_obj, + api_base: str, + api_key: str, + api_version: str, + data: dict, + model: str, + timeout: Any, + azure_ad_token: Optional[str] = None, + client=None, + ): max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): - raise AzureOpenAIError(status_code=422, message="max retries must be an int") + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) # init AzureOpenAI Client azure_client_params = { "api_version": api_version, @@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM): "azure_deployment": model, "http_client": litellm.client_session, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -300,25 +368,36 @@ class AzureChatCompletion(BaseLLM): azure_client = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=azure_client.api_key, - additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": 
azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, ) response = azure_client.chat.completions.create(**data) - streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + ) return streamwrapper - async def async_streaming(self, - logging_obj, - api_base: str, - api_key: str, - api_version: str, - data: dict, - model: str, - timeout: Any, - azure_ad_token: Optional[str]=None, - client = None, - ): + async def async_streaming( + self, + logging_obj, + api_base: str, + api_key: str, + api_version: str, + data: dict, + model: str, + timeout: Any, + azure_ad_token: Optional[str] = None, + client=None, + ): # init AzureOpenAI Client azure_client_params = { "api_version": api_version, @@ -326,39 +405,49 @@ class AzureChatCompletion(BaseLLM): "azure_deployment": model, "http_client": litellm.client_session, "max_retries": data.pop("max_retries", 2), - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: azure_client_params["azure_ad_token"] = azure_ad_token if client is None: - azure_client = AsyncAzureOpenAI(**azure_client_params) + azure_client = AsyncAzureOpenAI(**azure_client_params) else: azure_client = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=azure_client.api_key, - additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, ) response = await azure_client.chat.completions.create(**data) - streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + ) async for transformed_chunk in streamwrapper: yield transformed_chunk async def aembedding( - self, - data: dict, - model_response: ModelResponse, + self, + data: dict, + model_response: ModelResponse, azure_client_params: dict, - api_key: str, - input: list, + api_key: str, + input: list, client=None, - logging_obj=None - ): + logging_obj=None, + ): response = None - try: + try: if client is None: openai_aclient = AsyncAzureOpenAI(**azure_client_params) else: @@ -367,50 +456,53 @@ class AzureChatCompletion(BaseLLM): stringified_response = response.model_dump_json() ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=stringified_response, - ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return 
convert_to_model_response_object( + response_object=json.loads(stringified_response), + model_response_object=model_response, + response_type="embedding", + ) except Exception as e: ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=str(e), - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) raise e - def embedding(self, - model: str, - input: list, - api_key: str, - api_base: str, - api_version: str, - timeout: float, - logging_obj=None, - model_response=None, - optional_params=None, - azure_ad_token: Optional[str]=None, - client = None, - aembedding=None, - ): + def embedding( + self, + model: str, + input: list, + api_key: str, + api_base: str, + api_version: str, + timeout: float, + logging_obj=None, + model_response=None, + optional_params=None, + azure_ad_token: Optional[str] = None, + client=None, + aembedding=None, + ): super().embedding() exception_mapping_worked = False if self._client_session is None: self._client_session = self.create_client_session() - try: - data = { - "model": model, - "input": input, - **optional_params - } + try: + data = {"model": model, "input": input, **optional_params} max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): - raise AzureOpenAIError(status_code=422, message="max retries must be an int") - + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + # init AzureOpenAI Client azure_client_params = { "api_version": api_version, @@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM): "azure_deployment": model, "http_client": litellm.client_session, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -427,119 +519,130 @@ class AzureChatCompletion(BaseLLM): ## LOGGING logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "headers": { - "api_key": api_key, - "azure_ad_token": azure_ad_token - } - }, - ) - + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token}, + }, + ) + if aembedding == True: - response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params) + response = self.aembedding( + data=data, + input=input, + logging_obj=logging_obj, + api_key=api_key, + model_response=model_response, + azure_client_params=azure_client_params, + ) return response if client is None: - azure_client = AzureOpenAI(**azure_client_params) # type: ignore + azure_client = AzureOpenAI(**azure_client_params) # type: ignore else: azure_client = client - ## COMPLETION CALL - response = azure_client.embeddings.create(**data) # type: ignore + ## COMPLETION CALL + response = azure_client.embeddings.create(**data) # type: ignore ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data, "api_base": api_base}, - original_response=response, - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data, "api_base": api_base}, + original_response=response, + ) - - return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, 
response_type="embedding") # type: ignore - except AzureOpenAIError as e: + return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore + except AzureOpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: - if exception_mapping_worked: + except Exception as e: + if exception_mapping_worked: raise e - else: + else: import traceback + raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) async def aimage_generation( - self, - data: dict, - model_response: ModelResponse, + self, + data: dict, + model_response: ModelResponse, azure_client_params: dict, - api_key: str, - input: list, + api_key: str, + input: list, client=None, - logging_obj=None - ): + logging_obj=None, + ): response = None - try: + try: if client is None: - client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) - openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params) + client_session = litellm.aclient_session or httpx.AsyncClient( + transport=AsyncCustomHTTPTransport(), + ) + openai_aclient = AsyncAzureOpenAI( + http_client=client_session, **azure_client_params + ) else: openai_aclient = client response = await openai_aclient.images.generate(**data) stringified_response = response.model_dump_json() ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=stringified_response, - ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object( + response_object=json.loads(stringified_response), + model_response_object=model_response, + response_type="image_generation", + ) except Exception as e: ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=str(e), - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) raise e - - def image_generation(self, - prompt: str, - timeout: float, - model: Optional[str]=None, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - model_response: Optional[litellm.utils.ImageResponse] = None, - azure_ad_token: Optional[str]=None, - logging_obj=None, - optional_params=None, - client=None, - aimg_generation=None, - ): + + def image_generation( + self, + prompt: str, + timeout: float, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model_response: Optional[litellm.utils.ImageResponse] = None, + azure_ad_token: Optional[str] = None, + logging_obj=None, + optional_params=None, + client=None, + aimg_generation=None, + ): exception_mapping_worked = False - try: + try: if model and len(model) > 0: model = model else: model = None - data = { - "model": model, - "prompt": prompt, - **optional_params - } + data = {"model": model, "prompt": prompt, **optional_params} max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): - raise AzureOpenAIError(status_code=422, message="max retries must be an int") - + if not 
isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + # init AzureOpenAI Client azure_client_params = { "api_version": api_version, "azure_endpoint": api_base, "azure_deployment": model, "max_retries": max_retries, - "timeout": timeout + "timeout": timeout, } if api_key is not None: azure_client_params["api_key"] = api_key @@ -547,39 +650,47 @@ class AzureChatCompletion(BaseLLM): azure_client_params["azure_ad_token"] = azure_ad_token if aimg_generation == True: - response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore + response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore return response - + if client is None: - client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) - azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore + client_session = litellm.client_session or httpx.Client( + transport=CustomHTTPTransport(), + ) + azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore else: azure_client = client - + ## LOGGING logging_obj.pre_call( input=prompt, api_key=azure_client.api_key, - additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": False, + "complete_input_dict": data, + }, ) - + ## COMPLETION CALL - response = azure_client.images.generate(**data) # type: ignore + response = azure_client.images.generate(**data) # type: ignore ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response, - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) # return response - return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore - except AzureOpenAIError as e: + return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore + except AzureOpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: - if exception_mapping_worked: + except Exception as e: + if exception_mapping_worked: raise e - else: + else: import traceback - raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) \ No newline at end of file + + raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) diff --git a/litellm/llms/base.py b/litellm/llms/base.py index a4c056e1a..5106153a5 100644 --- a/litellm/llms/base.py +++ b/litellm/llms/base.py @@ -1,47 +1,45 @@ ## This is a template base class to be used for adding new LLM providers via API calls -import litellm +import litellm import httpx, certifi, ssl from typing import Optional + class BaseLLM: _client_session: Optional[httpx.Client] = None + def create_client_session(self): - if 
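Every code path in `AzureChatCompletion` (sync, async, streaming, embeddings, image generation) assembles the same `azure_client_params` dict before constructing `AzureOpenAI` / `AsyncAzureOpenAI`: endpoint, deployment, API version, retries, timeout, and exactly one of `api_key` or `azure_ad_token`. A condensed sketch of that assembly, including the Cloudflare AI Gateway special case where the deployment is appended to the base URL instead (the helper name and placeholder values are illustrative; the real code also threads `litellm.client_session` through as `http_client`):

    from typing import Optional

    def build_azure_client_params(
        api_base: str,
        api_version: str,
        model: str,
        timeout: float,
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        max_retries: int = 2,
    ) -> dict:
        if "gateway.ai.cloudflare.com" in api_base:
            # Gateway URLs already include the resource name, so the deployment
            # becomes part of the path rather than a separate parameter.
            if not api_base.endswith("/"):
                api_base += "/"
            params = {"api_version": api_version, "base_url": f"{api_base}{model}"}
        else:
            params = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
            }
        params.update({"max_retries": max_retries, "timeout": timeout})
        if api_key is not None:            # key-based auth takes precedence
            params["api_key"] = api_key
        elif azure_ad_token is not None:   # otherwise fall back to Azure AD
            params["azure_ad_token"] = azure_ad_token
        return params

    params = build_azure_client_params(
        api_base="https://my-resource.openai.azure.com/",  # placeholder
        api_version="2023-07-01-preview",                  # placeholder
        model="my-gpt-deployment",                         # placeholder
        timeout=600.0,
    )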
litellm.client_session: + if litellm.client_session: _client_session = litellm.client_session - else: + else: _client_session = httpx.Client() - + return _client_session - + def create_aclient_session(self): - if litellm.aclient_session: + if litellm.aclient_session: _aclient_session = litellm.aclient_session - else: + else: _aclient_session = httpx.AsyncClient() - + return _aclient_session - + def __exit__(self): - if hasattr(self, '_client_session'): + if hasattr(self, "_client_session"): self._client_session.close() - + async def __aexit__(self, exc_type, exc_val, exc_tb): - if hasattr(self, '_aclient_session'): + if hasattr(self, "_aclient_session"): await self._aclient_session.aclose() def validate_environment(self): # set up the environment required to run the model pass def completion( - self, - *args, - **kwargs + self, *args, **kwargs ): # logic for parsing in - calling - parsing out model completion calls pass def embedding( - self, - *args, - **kwargs + self, *args, **kwargs ): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py index b1e904c7a..b94491014 100644 --- a/litellm/llms/baseten.py +++ b/litellm/llms/baseten.py @@ -6,6 +6,7 @@ import time from typing import Callable from litellm.utils import ModelResponse, Usage + class BasetenError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -14,6 +15,7 @@ class BasetenError(Exception): self.message ) # Call the base class constructor with the parameters it needs + def validate_environment(api_key): headers = { "accept": "application/json", @@ -23,6 +25,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Api-Key {api_key}" return headers + def completion( model: str, messages: list, @@ -52,32 +55,38 @@ def completion( "inputs": prompt, "prompt": prompt, "parameters": optional_params, - "stream": True if "stream" in optional_params and optional_params["stream"] == True else False + "stream": True + if "stream" in optional_params and optional_params["stream"] == True + else False, } ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( completion_url_fragment_1 + model + completion_url_fragment_2, headers=headers, data=json.dumps(data), - stream=True if "stream" in optional_params and optional_params["stream"] == True else False + stream=True + if "stream" in optional_params and optional_params["stream"] == True + else False, ) - if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True): + if "text/event-stream" in response.headers["Content-Type"] or ( + "stream" in optional_params and optional_params["stream"] == True + ): return response.iter_lines() else: ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() @@ -91,9 +100,7 @@ def completion( if ( isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] - and isinstance( - 
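The `BaseLLM` changes are formatting-only, but the class is also the hook for connection reuse: `create_client_session()` and `create_aclient_session()` return `litellm.client_session` / `litellm.aclient_session` when the application has set them, and only fall back to fresh `httpx` clients otherwise. A short sketch of opting into a shared pool (the timeout value is illustrative):

    import httpx
    import litellm

    # One shared connection pool for litellm's HTTP calls; handlers built on
    # BaseLLM, and the Azure/OpenAI clients above, pick these up when set.
    litellm.client_session = httpx.Client(timeout=600.0)
    litellm.aclient_session = httpx.AsyncClient(timeout=600.0)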
completion_response["model_output"]["data"], list - ) + and isinstance(completion_response["model_output"]["data"], list) ): model_response["choices"][0]["message"][ "content" @@ -112,12 +119,19 @@ def completion( if "generated_text" not in completion_response: raise BasetenError( message=f"Unable to parse response. Original response: {response.text}", - status_code=response.status_code + status_code=response.status_code, ) - model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] - ## GETTING LOGPROBS - if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: - model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] + model_response["choices"][0]["message"][ + "content" + ] = completion_response[0]["generated_text"] + ## GETTING LOGPROBS + if ( + "details" in completion_response[0] + and "tokens" in completion_response[0]["details"] + ): + model_response.choices[0].finish_reason = completion_response[0][ + "details" + ]["finish_reason"] sum_logprob = 0 for token in completion_response[0]["details"]["tokens"]: sum_logprob += token["logprob"] @@ -125,7 +139,7 @@ def completion( else: raise BasetenError( message=f"Unable to parse response. Original response: {response.text}", - status_code=response.status_code + status_code=response.status_code, ) ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. @@ -139,11 +153,12 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 561fb037d..56f6b9e7d 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage from .prompt_templates.factory import prompt_factory, custom_prompt import httpx + class BedrockError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock") + self.request = httpx.Request( + method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class AmazonTitanConfig(): + +class AmazonTitanConfig: """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 @@ -29,29 +33,44 @@ class AmazonTitanConfig(): - `temperature` (float) temperature for model, - `topP` (int) top p for model """ - maxTokenCount: Optional[int]=None - stopSequences: Optional[list]=None - temperature: Optional[float]=None - topP: Optional[int]=None - def __init__(self, - maxTokenCount: Optional[int]=None, - stopSequences: Optional[list]=None, - temperature: Optional[float]=None, - topP: Optional[int]=None) -> None: + maxTokenCount: Optional[int] = None + stopSequences: Optional[list] = None + temperature: Optional[float] = None + topP: Optional[int] = None + + def __init__( + self, + maxTokenCount: Optional[int] = None, + stopSequences: Optional[list] = None, + temperature: Optional[float] 
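The Baseten handler ends the same way as the AI21, Aleph Alpha, and Anthropic handlers above: because these providers do not reliably report token usage, the handlers count tokens themselves by encoding the prompt and the returned text, then attach a `Usage` object to the `ModelResponse`. A minimal sketch of that bookkeeping (the `cl100k_base` encoding is an assumption for illustration; inside the handlers the `encoding` object is passed in by litellm rather than created locally):

    import tiktoken
    from litellm.utils import Usage

    encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding

    prompt = "Hello, how are you?"
    completion_text = "I'm doing well, thanks for asking."

    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(encoding.encode(completion_text))

    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )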
= None, + topP: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } -class AmazonAnthropicConfig(): + +class AmazonAnthropicConfig: """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude @@ -64,33 +83,48 @@ class AmazonAnthropicConfig(): - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"], - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" """ - max_tokens_to_sample: Optional[int]=litellm.max_tokens - stop_sequences: Optional[list]=None - temperature: Optional[float]=None - top_k: Optional[int]=None - top_p: Optional[int]=None - anthropic_version: Optional[str]=None - def __init__(self, - max_tokens_to_sample: Optional[int]=None, - stop_sequences: Optional[list]=None, - temperature: Optional[float]=None, - top_k: Optional[int]=None, - top_p: Optional[int]=None, - anthropic_version: Optional[str]=None) -> None: + max_tokens_to_sample: Optional[int] = litellm.max_tokens + stop_sequences: Optional[list] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[int] = None + anthropic_version: Optional[str] = None + + def __init__( + self, + max_tokens_to_sample: Optional[int] = None, + stop_sequences: Optional[list] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + anthropic_version: Optional[str] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } -class AmazonCohereConfig(): + +class AmazonCohereConfig: """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command @@ -100,79 +134,110 @@ class AmazonCohereConfig(): - `temperature` (float) model temperature, - `return_likelihood` (string) n/a """ - max_tokens: Optional[int]=None - temperature: Optional[float]=None - return_likelihood: Optional[str]=None - def __init__(self, - max_tokens: Optional[int]=None, - temperature: Optional[float]=None, - return_likelihood: Optional[str]=None) -> None: + max_tokens: Optional[int] = None + temperature: Optional[float] = None + return_likelihood: Optional[str] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + return_likelihood: Optional[str] = None, + ) -> None: locals_ = locals() for 
key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } -class AmazonAI21Config(): + +class AmazonAI21Config: """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra Supported Params for the Amazon / AI21 models: - + - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. - + - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. - + - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. - + - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. - + - `frequencyPenalty` (object): Placeholder for frequency penalty object. - + - `presencePenalty` (object): Placeholder for presence penalty object. - + - `countPenalty` (object): Placeholder for count penalty object. """ - maxTokens: Optional[int]=None - temperature: Optional[float]=None - topP: Optional[float]=None - stopSequences: Optional[list]=None - frequencePenalty: Optional[dict]=None - presencePenalty: Optional[dict]=None - countPenalty: Optional[dict]=None - def __init__(self, - maxTokens: Optional[int]=None, - temperature: Optional[float]=None, - topP: Optional[float]=None, - stopSequences: Optional[list]=None, - frequencePenalty: Optional[dict]=None, - presencePenalty: Optional[dict]=None, - countPenalty: Optional[dict]=None) -> None: + maxTokens: Optional[int] = None + temperature: Optional[float] = None + topP: Optional[float] = None + stopSequences: Optional[list] = None + frequencePenalty: Optional[dict] = None + presencePenalty: Optional[dict] = None + countPenalty: Optional[dict] = None + + def __init__( + self, + maxTokens: Optional[int] = None, + temperature: Optional[float] = None, + topP: Optional[float] = None, + stopSequences: Optional[list] = None, + frequencePenalty: Optional[dict] = None, + presencePenalty: Optional[dict] = None, + countPenalty: Optional[dict] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + class AnthropicConstants(Enum): HUMAN_PROMPT = 
"\n\nHuman: " AI_PROMPT = "\n\nAssistant: " -class AmazonLlamaConfig(): + +class AmazonLlamaConfig: """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1 @@ -182,48 +247,72 @@ class AmazonLlamaConfig(): - `temperature` (float) temperature for model, - `top_p` (float) top p for model """ - max_gen_len: Optional[int]=None - temperature: Optional[float]=None - topP: Optional[float]=None - def __init__(self, - maxTokenCount: Optional[int]=None, - temperature: Optional[float]=None, - topP: Optional[int]=None) -> None: + max_gen_len: Optional[int] = None + temperature: Optional[float] = None + topP: Optional[float] = None + + def __init__( + self, + maxTokenCount: Optional[int] = None, + temperature: Optional[float] = None, + topP: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def init_bedrock_client( - region_name = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_region_name: Optional[str] =None, - aws_bedrock_runtime_endpoint: Optional[str]=None, - ): + region_name=None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_bedrock_runtime_endpoint: Optional[str] = None, +): # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) standard_aws_region_name = get_secret("AWS_REGION", None) - - ## CHECK IS 'os.environ/' passed in - # Define the list of parameters to check - params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint] + + ## CHECK IS 'os.environ/' passed in + # Define the list of parameters to check + params_to_check = [ + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_bedrock_runtime_endpoint, + ] # Iterate over parameters and update if needed for i, param in enumerate(params_to_check): - if param and param.startswith('os.environ/'): + if param and param.startswith("os.environ/"): params_to_check[i] = get_secret(param) # Assign updated values back to parameters - aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check + ( + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_bedrock_runtime_endpoint, + ) = params_to_check if region_name: pass elif aws_region_name: @@ -233,7 +322,10 @@ def init_bedrock_client( elif standard_aws_region_name: region_name = standard_aws_region_name else: - raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401) + raise BedrockError( + message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", + status_code=401, + ) # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to 
init_bedrock_client env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") @@ -242,9 +334,10 @@ def init_bedrock_client( elif env_aws_bedrock_runtime_endpoint: endpoint_url = env_aws_bedrock_runtime_endpoint else: - endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com' + endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com" import boto3 + if aws_access_key_id != None: # uses auth params passed to completion # aws_access_key_id is not None, assume user is trying to auth using litellm.completion @@ -257,7 +350,7 @@ def init_bedrock_client( endpoint_url=endpoint_url, ) else: - # aws_access_key_id is None, assume user is trying to auth using env variables + # aws_access_key_id is None, assume user is trying to auth using env variables # boto3 automatically reads env variables client = boto3.client( @@ -276,25 +369,23 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: - prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic") + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) else: prompt = "" for message in messages: if "role" in message: if message["role"] == "user": - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: prompt += f"{message['content']}" return prompt @@ -309,17 +400,18 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = "" # set os.environ['AWS_REGION_NAME'] = + def completion( - model: str, - messages: list, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params=None, - litellm_params=None, - logger_fn=None, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, ): exception_mapping_worked = False try: @@ -327,7 +419,9 @@ def completion( aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = optional_params.pop( @@ -343,67 +437,71 @@ def completion( model = model provider = model.split(".")[0] - prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) + prompt = convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) inference_params = copy.deepcopy(optional_params) stream = inference_params.pop("stream", False) if provider == "anthropic": ## LOAD CONFIG - config = 
litellm.AmazonAnthropicConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AmazonAnthropicConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - data = json.dumps({ - "prompt": prompt, - **inference_params - }) + data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "ai21": ## LOAD CONFIG - config = litellm.AmazonAI21Config.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AmazonAI21Config.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - data = json.dumps({ - "prompt": prompt, - **inference_params - }) + data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "cohere": ## LOAD CONFIG - config = litellm.AmazonCohereConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AmazonCohereConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if optional_params.get("stream", False) == True: - inference_params["stream"] = True # cohere requires stream = True in inference params - data = json.dumps({ - "prompt": prompt, - **inference_params - }) + inference_params[ + "stream" + ] = True # cohere requires stream = True in inference params + data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "meta": ## LOAD CONFIG config = litellm.AmazonLlamaConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - data = json.dumps({ - "prompt": prompt, - **inference_params - }) + data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "amazon": # amazon titan ## LOAD CONFIG - config = litellm.AmazonTitanConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.AmazonTitanConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - data = json.dumps({ - "inputText": prompt, - "textGenerationConfig": inference_params, - }) - + data = json.dumps( + { + "inputText": prompt, + "textGenerationConfig": inference_params, + } + ) + ## COMPLETION CALL - accept = 'application/json' - contentType = 'application/json' + accept = "application/json" + contentType = "application/json" if stream == True: if provider == "ai21": ## LOGGING @@ -418,17 +516,17 @@ def 
completion( logging_obj.pre_call( input=prompt, api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, + additional_args={ + "complete_input_dict": data, + "request_str": request_str, + }, ) response = client.invoke_model( - body=data, - modelId=model, - accept=accept, - contentType=contentType + body=data, modelId=model, accept=accept, contentType=contentType ) - response = response.get('body').read() + response = response.get("body").read() return response else: ## LOGGING @@ -441,20 +539,20 @@ def completion( ) """ logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": data, + "request_str": request_str, + }, ) - + response = client.invoke_model_with_response_stream( - body=data, - modelId=model, - accept=accept, - contentType=contentType + body=data, modelId=model, accept=accept, contentType=contentType ) - response = response.get('body') + response = response.get("body") return response - try: + try: ## LOGGING request_str = f""" response = client.invoke_model( @@ -465,20 +563,20 @@ def completion( ) """ logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, - ) - response = client.invoke_model( - body=data, - modelId=model, - accept=accept, - contentType=contentType + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": data, + "request_str": request_str, + }, ) - except Exception as e: + response = client.invoke_model( + body=data, modelId=model, accept=accept, contentType=contentType + ) + except Exception as e: raise BedrockError(status_code=500, message=str(e)) - - response_body = json.loads(response.get('body').read()) + + response_body = json.loads(response.get("body").read()) ## LOGGING logging_obj.post_call( @@ -491,16 +589,16 @@ def completion( ## RESPONSE OBJECT outputText = "default" if provider == "ai21": - outputText = response_body.get('completions')[0].get('data').get('text') + outputText = response_body.get("completions")[0].get("data").get("text") elif provider == "anthropic": - outputText = response_body['completion'] + outputText = response_body["completion"] model_response["finish_reason"] = response_body["stop_reason"] - elif provider == "cohere": + elif provider == "cohere": outputText = response_body["generations"][0]["text"] - elif provider == "meta": + elif provider == "meta": outputText = response_body["generation"] else: # amazon titan - outputText = response_body.get('results')[0].get('outputText') + outputText = response_body.get("results")[0].get("outputText") response_metadata = response.get("ResponseMetadata", {}) if response_metadata.get("HTTPStatusCode", 500) >= 400: @@ -513,12 +611,13 @@ def completion( if len(outputText) > 0: model_response["choices"][0]["message"]["content"] = outputText except: - raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500)) + raise BedrockError( + message=json.dumps(outputText), + status_code=response_metadata.get("HTTPStatusCode", 500), + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
+ prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -528,41 +627,47 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens = prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response except BedrockError as e: exception_mapping_worked = True raise e - except Exception as e: + except Exception as e: if exception_mapping_worked: raise e - else: + else: import traceback + raise BedrockError(status_code=500, message=traceback.format_exc()) + def _embedding_func_single( - model: str, - input: str, - client: Any, - optional_params=None, - encoding=None, - logging_obj=None, + model: str, + input: str, + client: Any, + optional_params=None, + encoding=None, + logging_obj=None, ): # logic for parsing in - calling - parsing out model embedding calls - ## FORMAT EMBEDDING INPUT ## + ## FORMAT EMBEDDING INPUT ## provider = model.split(".")[0] inference_params = copy.deepcopy(optional_params) - inference_params.pop("user", None) # make sure user is not passed in for bedrock call + inference_params.pop( + "user", None + ) # make sure user is not passed in for bedrock call if provider == "amazon": input = input.replace(os.linesep, " ") data = {"inputText": input, **inference_params} # data = json.dumps(data) elif provider == "cohere": - inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 - data = {"texts": [input], **inference_params} # type: ignore - body = json.dumps(data).encode("utf-8") + inference_params["input_type"] = inference_params.get( + "input_type", "search_document" + ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 + data = {"texts": [input], **inference_params} # type: ignore + body = json.dumps(data).encode("utf-8") ## LOGGING request_str = f""" response = client.invoke_model( @@ -570,12 +675,14 @@ def _embedding_func_single( modelId={model}, accept="*/*", contentType="application/json", - )""" # type: ignore + )""" # type: ignore logging_obj.pre_call( input=input, - api_key="", # boto3 is used for init. - additional_args={"complete_input_dict": {"model": model, - "texts": input}, "request_str": request_str}, + api_key="", # boto3 is used for init. 
+ additional_args={ + "complete_input_dict": {"model": model, "texts": input}, + "request_str": request_str, + }, ) try: response = client.invoke_model( @@ -587,11 +694,11 @@ def _embedding_func_single( response_body = json.loads(response.get("body").read()) ## LOGGING logging_obj.post_call( - input=input, - api_key="", - additional_args={"complete_input_dict": data}, - original_response=json.dumps(response_body), - ) + input=input, + api_key="", + additional_args={"complete_input_dict": data}, + original_response=json.dumps(response_body), + ) if provider == "cohere": response = response_body.get("embeddings") # flatten list @@ -600,7 +707,10 @@ def _embedding_func_single( elif provider == "amazon": return response_body.get("embedding") except Exception as e: - raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500) + raise BedrockError( + message=f"Embedding Error with model {model}: {e}", status_code=500 + ) + def embedding( model: str, @@ -616,7 +726,9 @@ def embedding( aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = init_bedrock_client( @@ -624,11 +736,19 @@ def embedding( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - ) + ) ## Embedding Call - embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls - + embeddings = [ + _embedding_func_single( + model, + i, + optional_params=optional_params, + client=client, + logging_obj=logging_obj, + ) + for i in input + ] # [TODO]: make these parallel calls ## Populate OpenAI compliant dictionary embedding_response = [] @@ -647,13 +767,11 @@ def embedding( input_str = "".join(input) - input_tokens+=len(encoding.encode(input_str)) + input_tokens += len(encoding.encode(input_str)) usage = Usage( - prompt_tokens=input_tokens, - completion_tokens=0, - total_tokens=input_tokens + 0 + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0 ) model_response.usage = usage - + return model_response diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py index 8581731b5..40b65439b 100644 --- a/litellm/llms/cohere.py +++ b/litellm/llms/cohere.py @@ -8,88 +8,106 @@ from litellm.utils import ModelResponse, Choices, Message, Usage import litellm import httpx + class CohereError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate") + self.request = httpx.Request( + method="POST", url="https://api.cohere.ai/v1/generate" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class CohereConfig(): + +class CohereConfig: """ Reference: https://docs.cohere.com/reference/generate The class `CohereConfig` provides configuration for the Cohere's API interface. 
Below are the parameters: - + - `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5. - + - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20. - + - `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END. - + - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75. - + - `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc. - + - `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text. - + - `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text. - + - `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0. - + - `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0. - + - `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens. - + - `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared. - + - `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE. - + - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. 
{"hello_world": 1233} """ - num_generations: Optional[int]=None - max_tokens: Optional[int]=None - truncate: Optional[str]=None - temperature: Optional[int]=None - preset: Optional[str]=None - end_sequences: Optional[list]=None - stop_sequences: Optional[list]=None - k: Optional[int]=None - p: Optional[int]=None - frequency_penalty: Optional[int]=None - presence_penalty: Optional[int]=None - return_likelihoods: Optional[str]=None - logit_bias: Optional[dict]=None - - def __init__(self, - num_generations: Optional[int]=None, - max_tokens: Optional[int]=None, - truncate: Optional[str]=None, - temperature: Optional[int]=None, - preset: Optional[str]=None, - end_sequences: Optional[list]=None, - stop_sequences: Optional[list]=None, - k: Optional[int]=None, - p: Optional[int]=None, - frequency_penalty: Optional[int]=None, - presence_penalty: Optional[int]=None, - return_likelihoods: Optional[str]=None, - logit_bias: Optional[dict]=None) -> None: - + + num_generations: Optional[int] = None + max_tokens: Optional[int] = None + truncate: Optional[str] = None + temperature: Optional[int] = None + preset: Optional[str] = None + end_sequences: Optional[list] = None + stop_sequences: Optional[list] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + return_likelihoods: Optional[str] = None + logit_bias: Optional[dict] = None + + def __init__( + self, + num_generations: Optional[int] = None, + max_tokens: Optional[int] = None, + truncate: Optional[str] = None, + temperature: Optional[int] = None, + preset: Optional[str] = None, + end_sequences: Optional[list] = None, + stop_sequences: Optional[list] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + return_likelihoods: Optional[str] = None, + logit_bias: Optional[dict] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + def validate_environment(api_key): headers = { @@ -100,6 +118,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Bearer {api_key}" return headers + def completion( model: str, messages: list, @@ -119,9 +138,11 @@ def completion( prompt = " ".join(message["content"] for message in messages) ## Load Config - config=litellm.CohereConfig.get_config() + config = litellm.CohereConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v data = { @@ -132,16 +153,23 @@ def completion( ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, - ) + input=prompt, + 
api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + }, + ) ## COMPLETION CALL response = requests.post( - completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, ) ## error handling for cohere calls - if response.status_code!=200: + if response.status_code != 200: raise CohereError(message=response.text, status_code=response.status_code) if "stream" in optional_params and optional_params["stream"] == True: @@ -149,11 +177,11 @@ def completion( else: ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() @@ -168,18 +196,22 @@ def completion( for idx, item in enumerate(completion_response["generations"]): if len(item["text"]) > 0: message_obj = Message(content=item["text"]) - else: + else: message_obj = Message(content=None) - choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) + choice_obj = Choices( + finish_reason=item["finish_reason"], + index=idx + 1, + message=message_obj, + ) choices_list.append(choice_obj) model_response["choices"] = choices_list except Exception as e: - raise CohereError(message=response.text, status_code=response.status_code) + raise CohereError( + message=response.text, status_code=response.status_code + ) ## CALCULATING USAGE - prompt_tokens = len( - encoding.encode(prompt) - ) + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -189,11 +221,12 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding( model: str, input: list, @@ -206,11 +239,7 @@ def embedding( headers = validate_environment(api_key) embed_url = "https://api.cohere.ai/v1/embed" model = model - data = { - "model": model, - "texts": input, - **optional_params - } + data = {"model": model, "texts": input, **optional_params} if "3" in model and "input_type" not in data: # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document" @@ -218,21 +247,19 @@ def embedding( ## LOGGING logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) - ## COMPLETION CALL - response = requests.post( - embed_url, headers=headers, data=json.dumps(data) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, ) + ## COMPLETION CALL + response = requests.post(embed_url, headers=headers, data=json.dumps(data)) ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response, - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) """ response { @@ -244,30 +271,23 @@ def embedding( 'usage' } """ - if 
response.status_code!=200: + if response.status_code != 200: raise CohereError(message=response.text, status_code=response.status_code) - embeddings = response.json()['embeddings'] + embeddings = response.json()["embeddings"] output_data = [] for idx, embedding in enumerate(embeddings): output_data.append( - { - "object": "embedding", - "index": idx, - "embedding": embedding - } + {"object": "embedding", "index": idx, "embedding": embedding} ) model_response["object"] = "list" model_response["data"] = output_data model_response["model"] = model input_tokens = 0 for text in input: - input_tokens+=len(encoding.encode(text)) + input_tokens += len(encoding.encode(text)) - model_response["usage"] = { - "prompt_tokens": input_tokens, + model_response["usage"] = { + "prompt_tokens": input_tokens, "total_tokens": input_tokens, } return model_response - - - \ No newline at end of file diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index 3bc50dda7..a62e1d666 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,20 +1,24 @@ import time, json, httpx, asyncio + class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): """ Async implementation of custom http transport """ - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: if "images/generations" in request.url.path and request.url.params[ "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict "2023-06-01-preview", "2023-07-01-preview", "2023-08-01-preview", "2023-09-01-preview", - "2023-10-01-preview", + "2023-10-01-preview", ]: - request.url = request.url.copy_with(path="/openai/images/generations:submit") + request.url = request.url.copy_with( + path="/openai/images/generations:submit" + ) response = await super().handle_async_request(request) operation_location_url = response.headers["operation-location"] request.url = httpx.URL(operation_location_url) @@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): start_time = time.time() while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: - timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} + timeout = { + "error": { + "code": "Timeout", + "message": "Operation polling timed out.", + } + } return httpx.Response( status_code=400, headers=response.headers, @@ -56,26 +65,30 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): ) return await super().handle_async_request(request) + class CustomHTTPTransport(httpx.HTTPTransport): """ This class was written as a workaround to support dall-e-2 on openai > v1.x Refer to this issue for more: https://github.com/openai/openai-python/issues/692 """ + def handle_request( self, request: httpx.Request, ) -> httpx.Response: if "images/generations" in request.url.path and request.url.params[ "api-version" - ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict "2023-06-01-preview", "2023-07-01-preview", "2023-08-01-preview", "2023-09-01-preview", - "2023-10-01-preview", + "2023-10-01-preview", ]: - request.url = 
request.url.copy_with(path="/openai/images/generations:submit") + request.url = request.url.copy_with( + path="/openai/images/generations:submit" + ) response = super().handle_request(request) operation_location_url = response.headers["operation-location"] request.url = httpx.URL(operation_location_url) @@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport): start_time = time.time() while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: - timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} + timeout = { + "error": { + "code": "Timeout", + "message": "Operation polling timed out.", + } + } return httpx.Response( status_code=400, headers=response.headers, @@ -115,4 +133,4 @@ class CustomHTTPTransport(httpx.HTTPTransport): content=json.dumps(result).encode("utf-8"), request=request, ) - return super().handle_request(request) \ No newline at end of file + return super().handle_request(request) diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index ebbad901a..6565faa04 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -8,17 +8,22 @@ import litellm import sys, httpx from .prompt_templates.factory import prompt_factory, custom_prompt + class GeminiError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat") + self.request = httpx.Request( + method="POST", + url="https://developers.generativeai.google/api/python/google/generativeai/chat", + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class GeminiConfig(): + +class GeminiConfig: """ Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig @@ -37,33 +42,44 @@ class GeminiConfig(): - `top_k` (int): Optional. The maximum number of tokens to consider when sampling. 
""" - candidate_count: Optional[int]=None - stop_sequences: Optional[list]=None - max_output_tokens: Optional[int]=None - temperature: Optional[float]=None - top_p: Optional[float]=None - top_k: Optional[int]=None + candidate_count: Optional[int] = None + stop_sequences: Optional[list] = None + max_output_tokens: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None - def __init__(self, - candidate_count: Optional[int]=None, - stop_sequences: Optional[list]=None, - max_output_tokens: Optional[int]=None, - temperature: Optional[float]=None, - top_p: Optional[float]=None, - top_k: Optional[int]=None) -> None: - + def __init__( + self, + candidate_count: Optional[int] = None, + stop_sequences: Optional[list] = None, + max_output_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def completion( @@ -83,42 +99,50 @@ def completion( try: import google.generativeai as genai except: - raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") + raise Exception( + "Importing google.generativeai failed, please run 'pip install -q google-generativeai" + ) genai.configure(api_key=api_key) - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages - ) + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) else: - prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini") + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="gemini" + ) - ## Load Config inference_params = copy.deepcopy(optional_params) - inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py - config = litellm.GeminiConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params.pop( + "stream", None + ) # palm does not support streaming, so we handle this by fake streaming in main.py + config = litellm.GeminiConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in 
inference_params[k] = v - ## LOGGING logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": {"inference_params": inference_params}}, - ) + input=prompt, + api_key="", + additional_args={"complete_input_dict": {"inference_params": inference_params}}, + ) ## COMPLETION CALL - try: - _model = genai.GenerativeModel(f'models/{model}') - response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params)) + try: + _model = genai.GenerativeModel(f"models/{model}") + response = _model.generate_content( + contents=prompt, + generation_config=genai.types.GenerationConfig(**inference_params), + ) except Exception as e: raise GeminiError( message=str(e), @@ -127,11 +151,11 @@ def completion( ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=response, - additional_args={"complete_input_dict": {}}, - ) + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": {}}, + ) print_verbose(f"raw model_response: {response}") ## RESPONSE OBJECT completion_response = response @@ -142,31 +166,34 @@ def completion( message_obj = Message(content=item.content.parts[0].text) else: message_obj = Message(content=None) - choice_obj = Choices(index=idx+1, message=message_obj) + choice_obj = Choices(index=idx + 1, message=message_obj) choices_list.append(choice_obj) model_response["choices"] = choices_list except Exception as e: traceback.print_exc() - raise GeminiError(message=traceback.format_exc(), status_code=response.status_code) - - try: + raise GeminiError( + message=traceback.format_exc(), status_code=response.status_code + ) + + try: completion_response = model_response["choices"][0]["message"].get("content") except: - raise GeminiError(status_code=400, message=f"No response received. Original response - {response}") + raise GeminiError( + status_code=400, + message=f"No response received. 
Original response - {response}", + ) ## CALCULATING USAGE - prompt_str = "" + prompt_str = "" for m in messages: if isinstance(m["content"], str): prompt_str += m["content"] elif isinstance(m["content"], list): for content in m["content"]: - if content["type"] == "text": + if content["type"] == "text": prompt_str += content["text"] - prompt_tokens = len( - encoding.encode(prompt_str) - ) + prompt_tokens = len(encoding.encode(prompt_str)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -174,13 +201,14 @@ def completion( model_response["created"] = int(time.time()) model_response["model"] = "gemini/" + model usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 4dae6e88f..0cc8c5697 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, from typing import Optional from .prompt_templates.factory import prompt_factory, custom_prompt + class HuggingfaceError(Exception): - def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): + def __init__( + self, + status_code, + message, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + ): self.status_code = status_code self.message = message if request is not None: self.request = request - else: - self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models") + else: + self.request = httpx.Request( + method="POST", url="https://api-inference.huggingface.co/models" + ) if response is not None: self.response = response - else: - self.response = httpx.Response(status_code=status_code, request=self.request) + else: + self.response = httpx.Response( + status_code=status_code, request=self.request + ) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class HuggingfaceConfig(): + +class HuggingfaceConfig: """ - Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate + Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate """ + best_of: Optional[int] = None decoder_input_details: Optional[bool] = None - details: Optional[bool] = True # enables returning logprobs + best of + details: Optional[bool] = True # enables returning logprobs + best of max_new_tokens: Optional[int] = None repetition_penalty: Optional[float] = None - return_full_text: Optional[bool] = False # by default don't return the input as part of the output + return_full_text: Optional[ + bool + ] = False # by default don't return the input as part of the output seed: Optional[int] = None temperature: Optional[float] = None top_k: Optional[int] = None @@ -46,50 +61,66 @@ class HuggingfaceConfig(): typical_p: Optional[float] = None watermark: Optional[bool] = None - def __init__(self, - best_of: Optional[int] = None, - decoder_input_details: Optional[bool] = None, - details: 
Optional[bool] = None, - max_new_tokens: Optional[int] = None, - repetition_penalty: Optional[float] = None, - return_full_text: Optional[bool] = None, - seed: Optional[int] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_n_tokens: Optional[int] = None, - top_p: Optional[int] = None, - truncate: Optional[int] = None, - typical_p: Optional[float] = None, - watermark: Optional[bool] = None - ) -> None: + def __init__( + self, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + details: Optional[bool] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[int] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } -def output_parser(generated_text: str): + +def output_parser(generated_text: str): """ - Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. + Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. 
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763 """ chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "", ""] - for token in chat_template_tokens: + for token in chat_template_tokens: if generated_text.strip().startswith(token): generated_text = generated_text.replace(token, "", 1) if generated_text.endswith(token): generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] return generated_text - + + tgi_models_cache = None conv_models_cache = None + + def read_tgi_conv_models(): try: global tgi_models_cache, conv_models_cache @@ -101,30 +132,38 @@ def read_tgi_conv_models(): tgi_models = set() script_directory = os.path.dirname(os.path.abspath(__file__)) # Construct the file path relative to the script's directory - file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt") + file_path = os.path.join( + script_directory, + "huggingface_llms_metadata", + "hf_text_generation_models.txt", + ) - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for line in file: tgi_models.add(line.strip()) - + # Cache the set for future use tgi_models_cache = tgi_models - + # If not, read the file and populate the cache - file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt") + file_path = os.path.join( + script_directory, + "huggingface_llms_metadata", + "hf_conversational_models.txt", + ) conv_models = set() - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for line in file: conv_models.add(line.strip()) # Cache the set for future use - conv_models_cache = conv_models + conv_models_cache = conv_models return tgi_models, conv_models except: return set(), set() def get_hf_task_for_model(model): - # read text file, cast it to set + # read text file, cast it to set # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt" tgi_models, conversational_models = read_tgi_conv_models() if model in tgi_models: @@ -134,9 +173,10 @@ def get_hf_task_for_model(model): elif "roneneldan/TinyStories" in model: return None else: - return "text-generation-inference" # default to tgi + return "text-generation-inference" # default to tgi -class Huggingface(BaseLLM): + +class Huggingface(BaseLLM): _client_session: Optional[httpx.Client] = None _aclient_session: Optional[httpx.AsyncClient] = None @@ -148,65 +188,93 @@ class Huggingface(BaseLLM): "content-type": "application/json", } if api_key and headers is None: - default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens + default_headers[ + "Authorization" + ] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens headers = default_headers elif headers: - headers=headers - else: + headers = headers + else: headers = default_headers return headers - def convert_to_model_response_object(self, - completion_response, - model_response, - task, - optional_params, - encoding, - input_text, - model): - if task == "conversational": - if len(completion_response["generated_text"]) > 0: # type: ignore + def convert_to_model_response_object( + self, + completion_response, + model_response, + task, + optional_params, + encoding, + input_text, + model, + ): + if task == "conversational": + if len(completion_response["generated_text"]) > 0: # type: ignore model_response["choices"][0]["message"][ "content" - ] = completion_response["generated_text"] # type: ignore - elif 
task == "text-generation-inference": - if (not isinstance(completion_response, list) + ] = completion_response[ + "generated_text" + ] # type: ignore + elif task == "text-generation-inference": + if ( + not isinstance(completion_response, list) or not isinstance(completion_response[0], dict) - or "generated_text" not in completion_response[0]): - raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}") + or "generated_text" not in completion_response[0] + ): + raise HuggingfaceError( + status_code=422, + message=f"response is not in expected format - {completion_response}", + ) - if len(completion_response[0]["generated_text"]) > 0: - model_response["choices"][0]["message"][ - "content" - ] = output_parser(completion_response[0]["generated_text"]) - ## GETTING LOGPROBS + FINISH REASON - if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: - model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] + if len(completion_response[0]["generated_text"]) > 0: + model_response["choices"][0]["message"]["content"] = output_parser( + completion_response[0]["generated_text"] + ) + ## GETTING LOGPROBS + FINISH REASON + if ( + "details" in completion_response[0] + and "tokens" in completion_response[0]["details"] + ): + model_response.choices[0].finish_reason = completion_response[0][ + "details" + ]["finish_reason"] sum_logprob = 0 for token in completion_response[0]["details"]["tokens"]: if token["logprob"] != None: sum_logprob += token["logprob"] model_response["choices"][0]["message"]._logprob = sum_logprob - if "best_of" in optional_params and optional_params["best_of"] > 1: - if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]: + if "best_of" in optional_params and optional_params["best_of"] > 1: + if ( + "details" in completion_response[0] + and "best_of_sequences" in completion_response[0]["details"] + ): choices_list = [] - for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]): + for idx, item in enumerate( + completion_response[0]["details"]["best_of_sequences"] + ): sum_logprob = 0 for token in item["tokens"]: if token["logprob"] != None: sum_logprob += token["logprob"] - if len(item["generated_text"]) > 0: - message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob) - else: + if len(item["generated_text"]) > 0: + message_obj = Message( + content=output_parser(item["generated_text"]), + logprobs=sum_logprob, + ) + else: message_obj = Message(content=None) - choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) + choice_obj = Choices( + finish_reason=item["finish_reason"], + index=idx + 1, + message=message_obj, + ) choices_list.append(choice_obj) model_response["choices"].extend(choices_list) else: - if len(completion_response[0]["generated_text"]) > 0: - model_response["choices"][0]["message"][ - "content" - ] = output_parser(completion_response[0]["generated_text"]) + if len(completion_response[0]["generated_text"]) > 0: + model_response["choices"][0]["message"]["content"] = output_parser( + completion_response[0]["generated_text"] + ) ## CALCULATING USAGE prompt_tokens = 0 try: @@ -221,12 +289,14 @@ class Huggingface(BaseLLM): completion_tokens = 0 try: completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) + encoding.encode( + 
model_response["choices"][0]["message"].get("content", "") + ) ) ##[TODO] use the llama2 tokenizer here except: # this should remain non blocking we should not block a response returning if calculating usage fails pass - else: + else: completion_tokens = 0 model_response["created"] = int(time.time()) @@ -234,13 +304,14 @@ class Huggingface(BaseLLM): usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage model_response._hidden_params["original_response"] = completion_response return model_response - def completion(self, + def completion( + self, model: str, messages: list, api_base: Optional[str], @@ -276,9 +347,11 @@ class Huggingface(BaseLLM): completion_url = f"https://api-inference.huggingface.co/models/{model}" ## Load Config - config=litellm.HuggingfaceConfig.get_config() + config = litellm.HuggingfaceConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v ### MAP INPUT PARAMS @@ -298,11 +371,11 @@ class Huggingface(BaseLLM): generated_responses.append(message["content"]) data = { "inputs": { - "text": text, - "past_user_inputs": past_user_inputs, - "generated_responses": generated_responses + "text": text, + "past_user_inputs": past_user_inputs, + "generated_responses": generated_responses, }, - "parameters": inference_params + "parameters": inference_params, } input_text = "".join(message["content"] for message in messages) elif task == "text-generation-inference": @@ -311,29 +384,39 @@ class Huggingface(BaseLLM): # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get( + "final_prompt_value", "" + ), + messages=messages, ) else: prompt = prompt_factory(model=model, messages=messages) data = { "inputs": prompt, "parameters": optional_params, - "stream": True if "stream" in optional_params and optional_params["stream"] == True else False, + "stream": True + if "stream" in optional_params and optional_params["stream"] == True + else False, } input_text = prompt else: # Non TGI and Conversational llms - # We need this branch, it removes 'details' and 'return_full_text' from params + # We need this branch, it removes 'details' and 'return_full_text' from params if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", {}), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + role_dict=model_prompt_details.get("roles", {}), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get( + 
"final_prompt_value", "" + ), bos_token=model_prompt_details.get("bos_token", ""), eos_token=model_prompt_details.get("eos_token", ""), messages=messages, @@ -346,52 +429,68 @@ class Huggingface(BaseLLM): data = { "inputs": prompt, "parameters": inference_params, - "stream": True if "stream" in optional_params and optional_params["stream"] == True else False, + "stream": True + if "stream" in optional_params and optional_params["stream"] == True + else False, } input_text = prompt ## LOGGING logging_obj.pre_call( - input=input_text, - api_key=api_key, - additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion}, - ) + input=input_text, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "task": task, + "headers": headers, + "api_base": completion_url, + "acompletion": acompletion, + }, + ) ## COMPLETION CALL - if acompletion is True: - ### ASYNC STREAMING + if acompletion is True: + ### ASYNC STREAMING if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore + return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore else: ### ASYNC COMPLETION - return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore + return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore ### SYNC STREAMING if "stream" in optional_params and optional_params["stream"] == True: response = requests.post( - completion_url, - headers=headers, - data=json.dumps(data), - stream=optional_params["stream"] + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"], ) return response.iter_lines() ### SYNC COMPLETION else: response = requests.post( - completion_url, - headers=headers, - data=json.dumps(data) + completion_url, headers=headers, data=json.dumps(data) ) ## Some servers might return streaming responses even though stream was not set to true. (e.g. 
Baseten) - is_streamed = False - if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream": + is_streamed = False + if ( + response.__dict__["headers"].get("Content-Type", "") + == "text/event-stream" + ): is_streamed = True - + # iterate over the complete streamed response, and return the final answer if is_streamed: - streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj) + streamed_response = CustomStreamWrapper( + completion_stream=response.iter_lines(), + model=model, + custom_llm_provider="huggingface", + logging_obj=logging_obj, + ) content = "" - for chunk in streamed_response: + for chunk in streamed_response: content += chunk["choices"][0]["delta"]["content"] - completion_response: List[Dict[str, Any]] = [{"generated_text": content}] + completion_response: List[Dict[str, Any]] = [ + {"generated_text": content} + ] ## LOGGING logging_obj.post_call( input=input_text, @@ -399,7 +498,7 @@ class Huggingface(BaseLLM): original_response=completion_response, additional_args={"complete_input_dict": data, "task": task}, ) - else: + else: ## LOGGING logging_obj.post_call( input=input_text, @@ -410,15 +509,20 @@ class Huggingface(BaseLLM): ## RESPONSE OBJECT try: completion_response = response.json() - if isinstance(completion_response, dict): + if isinstance(completion_response, dict): completion_response = [completion_response] except: import traceback + raise HuggingfaceError( - message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code + message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", + status_code=response.status_code, ) print_verbose(f"response: {completion_response}") - if isinstance(completion_response, dict) and "error" in completion_response: + if ( + isinstance(completion_response, dict) + and "error" in completion_response + ): print_verbose(f"completion error: {completion_response['error']}") print_verbose(f"response.status_code: {response.status_code}") raise HuggingfaceError( @@ -432,75 +536,98 @@ class Huggingface(BaseLLM): optional_params=optional_params, encoding=encoding, input_text=input_text, - model=model + model=model, ) - except HuggingfaceError as e: + except HuggingfaceError as e: exception_mapping_worked = True raise e - except Exception as e: - if exception_mapping_worked: + except Exception as e: + if exception_mapping_worked: raise e - else: + else: import traceback + raise HuggingfaceError(status_code=500, message=traceback.format_exc()) - async def acompletion(self, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - task: str, - encoding: Any, - input_text: str, - model: str, - optional_params: dict): - response = None - try: + async def acompletion( + self, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + task: str, + encoding: Any, + input_text: str, + model: str, + optional_params: dict, + ): + response = None + try: async with httpx.AsyncClient() as client: - response = await client.post(url=api_base, json=data, headers=headers, timeout=None) + response = await client.post( + url=api_base, json=data, headers=headers, timeout=None + ) response_json = response.json() if response.status_code != 200: - raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response) - + raise HuggingfaceError( + 
status_code=response.status_code, + message=response.text, + request=response.request, + response=response, + ) + ## RESPONSE OBJECT - return self.convert_to_model_response_object(completion_response=response_json, - model_response=model_response, - task=task, - encoding=encoding, - input_text=input_text, - model=model, - optional_params=optional_params) - except Exception as e: - if isinstance(e,httpx.TimeoutException): + return self.convert_to_model_response_object( + completion_response=response_json, + model_response=model_response, + task=task, + encoding=encoding, + input_text=input_text, + model=model, + optional_params=optional_params, + ) + except Exception as e: + if isinstance(e, httpx.TimeoutException): raise HuggingfaceError(status_code=500, message="Request Timeout Error") - elif response is not None and hasattr(response, "text"): - raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}") - else: + elif response is not None and hasattr(response, "text"): + raise HuggingfaceError( + status_code=500, + message=f"{str(e)}\n\nOriginal Response: {response.text}", + ) + else: raise HuggingfaceError(status_code=500, message=f"{str(e)}") - async def async_streaming(self, - logging_obj, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - model: str): + async def async_streaming( + self, + logging_obj, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + model: str, + ): async with httpx.AsyncClient() as client: response = client.stream( - "POST", - url=f"{api_base}", - json=data, - headers=headers - ) - async with response as r: + "POST", url=f"{api_base}", json=data, headers=headers + ) + async with response as r: if r.status_code != 200: - raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming") - - streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj) + raise HuggingfaceError( + status_code=r.status_code, + message="An error occurred while streaming", + ) + + streamwrapper = CustomStreamWrapper( + completion_stream=r.aiter_lines(), + model=model, + custom_llm_provider="huggingface", + logging_obj=logging_obj, + ) async for transformed_chunk in streamwrapper: yield transformed_chunk - def embedding(self, + def embedding( + self, model: str, input: list, api_key: Optional[str] = None, @@ -523,65 +650,70 @@ class Huggingface(BaseLLM): embed_url = os.getenv("HUGGINGFACE_API_BASE", "") else: embed_url = f"https://api-inference.huggingface.co/models/{model}" - - if "sentence-transformers" in model: - if len(input) == 0: - raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences") + + if "sentence-transformers" in model: + if len(input) == 0: + raise HuggingfaceError( + status_code=400, + message="sentence transformers requires 2+ sentences", + ) data = { "inputs": { - "source_sentence": input[0], - "sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ] + "source_sentence": input[0], + "sentences": [ + "That is a happy dog", + "That is a very happy person", + "Today is a sunny day", + ], } } else: - data = { - "inputs": input # type: ignore - } - + data = {"inputs": input} # type: ignore + ## LOGGING logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url}, - ) - ## COMPLETION CALL - response = 
requests.post( - embed_url, headers=headers, data=json.dumps(data) + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": embed_url, + }, ) + ## COMPLETION CALL + response = requests.post(embed_url, headers=headers, data=json.dumps(data)) - ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response, - ) - + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) embeddings = response.json() - if "error" in embeddings: - raise HuggingfaceError(status_code=500, message=embeddings['error']) - + if "error" in embeddings: + raise HuggingfaceError(status_code=500, message=embeddings["error"]) + output_data = [] - if "similarities" in embeddings: + if "similarities" in embeddings: for idx, embedding in embeddings["similarities"]: output_data.append( - { - "object": "embedding", - "index": idx, - "embedding": embedding # flatten list returned from hf - } - ) - else: + { + "object": "embedding", + "index": idx, + "embedding": embedding, # flatten list returned from hf + } + ) + else: for idx, embedding in enumerate(embeddings): - if isinstance(embedding, float): + if isinstance(embedding, float): output_data.append( { "object": "embedding", "index": idx, - "embedding": embedding # flatten list returned from hf + "embedding": embedding, # flatten list returned from hf } ) elif isinstance(embedding, list) and isinstance(embedding[0], float): @@ -589,15 +721,17 @@ class Huggingface(BaseLLM): { "object": "embedding", "index": idx, - "embedding": embedding # flatten list returned from hf + "embedding": embedding, # flatten list returned from hf } ) - else: + else: output_data.append( { "object": "embedding", "index": idx, - "embedding": embedding[0][0] # flatten list returned from hf + "embedding": embedding[0][ + 0 + ], # flatten list returned from hf } ) model_response["object"] = "list" @@ -605,13 +739,10 @@ class Huggingface(BaseLLM): model_response["model"] = model input_tokens = 0 for text in input: - input_tokens+=len(encoding.encode(text)) + input_tokens += len(encoding.encode(text)) - model_response["usage"] = { - "prompt_tokens": input_tokens, + model_response["usage"] = { + "prompt_tokens": input_tokens, "total_tokens": input_tokens, } return model_response - - - diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py index 68a3a4e32..77267b13d 100644 --- a/litellm/llms/maritalk.py +++ b/litellm/llms/maritalk.py @@ -7,6 +7,7 @@ from typing import Callable, Optional, List from litellm.utils import ModelResponse, Choices, Message, Usage import litellm + class MaritalkError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -15,24 +16,26 @@ class MaritalkError(Exception): self.message ) # Call the base class constructor with the parameters it needs -class MaritTalkConfig(): + +class MaritTalkConfig: """ The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters: - + - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1. - + - `model` (string): The model used for conversation. Default is 'maritalk'. - + - `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True. - + - `temperature` (number): A non-negative float controlling the randomness in generation. 
Lower temperatures result in less random generations. Default is 0.7. - + - `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95. - + - `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1. - + - `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped. """ + max_tokens: Optional[int] = None model: Optional[str] = None do_sample: Optional[bool] = None @@ -41,27 +44,40 @@ class MaritTalkConfig(): repetition_penalty: Optional[float] = None stopping_tokens: Optional[List[str]] = None - def __init__(self, - max_tokens: Optional[int]=None, - model: Optional[str] = None, - do_sample: Optional[bool] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - stopping_tokens: Optional[List[str]] = None) -> None: - + def __init__( + self, + max_tokens: Optional[int] = None, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + stopping_tokens: Optional[List[str]] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def validate_environment(api_key): headers = { "accept": "application/json", @@ -71,6 +87,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Key {api_key}" return headers + def completion( model: str, messages: list, @@ -89,9 +106,11 @@ def completion( model = model ## Load Config - config=litellm.MaritTalkConfig.get_config() + config = litellm.MaritTalkConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v data = { @@ -101,24 +120,27 @@ def completion( ## LOGGING logging_obj.pre_call( - input=messages, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) + input=messages, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( - completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, ) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() else: ## LOGGING logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) 
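# A minimal sketch, using only the MaritTalkConfig class defined above, of the
# precedence rule noted in the config-merge loop: per-call kwargs win over
# class-level defaults set through MaritTalkConfig. The values below are
# illustrative placeholders, not defaults shipped with litellm.
import litellm

litellm.MaritTalkConfig(temperature=0.3, max_tokens=256)  # provider-wide defaults
optional_params = {"temperature": 0.9}                    # caller-supplied kwarg

for k, v in litellm.MaritTalkConfig.get_config().items():
    if k not in optional_params:  # caller's value is kept; config only fills the gaps
        optional_params[k] = v

# optional_params is now {"temperature": 0.9, "max_tokens": 256}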
print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() @@ -130,15 +152,17 @@ def completion( else: try: if len(completion_response["answer"]) > 0: - model_response["choices"][0]["message"]["content"] = completion_response["answer"] + model_response["choices"][0]["message"][ + "content" + ] = completion_response["answer"] except Exception as e: - raise MaritalkError(message=response.text, status_code=response.status_code) + raise MaritalkError( + message=response.text, status_code=response.status_code + ) ## CALCULATING USAGE prompt = "".join(m["content"] for m in messages) - prompt_tokens = len( - encoding.encode(prompt) - ) + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -148,11 +172,12 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding( model: str, input: list, @@ -161,4 +186,4 @@ def embedding( model_response=None, encoding=None, ): - pass \ No newline at end of file + pass diff --git a/litellm/llms/nlp_cloud.py b/litellm/llms/nlp_cloud.py index 8d09b85ea..f827975ce 100644 --- a/litellm/llms/nlp_cloud.py +++ b/litellm/llms/nlp_cloud.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import litellm from litellm.utils import ModelResponse, Usage + class NLPCloudError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -15,7 +16,8 @@ class NLPCloudError(Exception): self.message ) # Call the base class constructor with the parameters it needs -class NLPCloudConfig(): + +class NLPCloudConfig: """ Reference: https://docs.nlpcloud.com/#generation @@ -43,45 +45,57 @@ class NLPCloudConfig(): - `num_return_sequences` (int): Optional. The number of independently computed returned sequences. 
""" - max_length: Optional[int]=None - length_no_input: Optional[bool]=None - end_sequence: Optional[str]=None - remove_end_sequence: Optional[bool]=None - remove_input: Optional[bool]=None - bad_words: Optional[list]=None - temperature: Optional[float]=None - top_p: Optional[float]=None - top_k: Optional[int]=None - repetition_penalty: Optional[float]=None - num_beams: Optional[int]=None - num_return_sequences: Optional[int]=None + max_length: Optional[int] = None + length_no_input: Optional[bool] = None + end_sequence: Optional[str] = None + remove_end_sequence: Optional[bool] = None + remove_input: Optional[bool] = None + bad_words: Optional[list] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + repetition_penalty: Optional[float] = None + num_beams: Optional[int] = None + num_return_sequences: Optional[int] = None - def __init__(self, - max_length: Optional[int]=None, - length_no_input: Optional[bool]=None, - end_sequence: Optional[str]=None, - remove_end_sequence: Optional[bool]=None, - remove_input: Optional[bool]=None, - bad_words: Optional[list]=None, - temperature: Optional[float]=None, - top_p: Optional[float]=None, - top_k: Optional[int]=None, - repetition_penalty: Optional[float]=None, - num_beams: Optional[int]=None, - num_return_sequences: Optional[int]=None) -> None: - + def __init__( + self, + max_length: Optional[int] = None, + length_no_input: Optional[bool] = None, + end_sequence: Optional[str] = None, + remove_end_sequence: Optional[bool] = None, + remove_input: Optional[bool] = None, + bad_words: Optional[list] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + repetition_penalty: Optional[float] = None, + num_beams: Optional[int] = None, + num_return_sequences: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def validate_environment(api_key): @@ -93,6 +107,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Token {api_key}" return headers + def completion( model: str, messages: list, @@ -110,9 +125,11 @@ def completion( headers = validate_environment(api_key) ## Load Config - config = litellm.NLPCloudConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.NLPCloudConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v completion_url_fragment_1 = api_base @@ -129,24 +146,31 @@ def completion( ## LOGGING logging_obj.pre_call( - input=text, - api_key=api_key, - additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, - ) + input=text, + api_key=api_key, + additional_args={ + 
"complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + }, + ) ## COMPLETION CALL response = requests.post( - completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, ) if "stream" in optional_params and optional_params["stream"] == True: return clean_and_iterate_chunks(response) else: ## LOGGING logging_obj.post_call( - input=text, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=text, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT try: @@ -161,11 +185,16 @@ def completion( else: try: if len(completion_response["generated_text"]) > 0: - model_response["choices"][0]["message"]["content"] = completion_response["generated_text"] + model_response["choices"][0]["message"][ + "content" + ] = completion_response["generated_text"] except: - raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code) + raise NLPCloudError( + message=json.dumps(completion_response), + status_code=response.status_code, + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. prompt_tokens = completion_response["nb_input_tokens"] completion_tokens = completion_response["nb_generated_tokens"] @@ -174,7 +203,7 @@ def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response @@ -187,25 +216,27 @@ def completion( # # Perform further processing based on your needs # return cleaned_chunk + # for line in response.iter_lines(): # if line: # yield process_chunk(line) def clean_and_iterate_chunks(response): - buffer = b'' + buffer = b"" for chunk in response.iter_content(chunk_size=1024): if not chunk: break buffer += chunk - while b'\x00' in buffer: - buffer = buffer.replace(b'\x00', b'') - yield buffer.decode('utf-8') - buffer = b'' + while b"\x00" in buffer: + buffer = buffer.replace(b"\x00", b"") + yield buffer.decode("utf-8") + buffer = b"" # No more data expected, yield any remaining data in the buffer if buffer: - yield buffer.decode('utf-8') + yield buffer.decode("utf-8") + def embedding(): # logic for parsing in - calling - parsing out model embedding calls diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 49e2f3d93..81e16a1a6 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -2,10 +2,11 @@ import requests, types, time import json, uuid import traceback from typing import Optional -import litellm +import litellm import httpx, aiohttp, asyncio from .prompt_templates.factory import prompt_factory, custom_prompt + class OllamaError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -16,14 +17,15 @@ class OllamaError(Exception): self.message ) # Call the base class constructor with the parameters it needs -class OllamaConfig(): + +class OllamaConfig: """ Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters The class `OllamaConfig` provides 
the configuration for the Ollama's API interface. Below are the parameters: - + - `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0 - + - `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1 - `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0 @@ -56,102 +58,134 @@ class OllamaConfig(): - `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile) """ - mirostat: Optional[int]=None - mirostat_eta: Optional[float]=None - mirostat_tau: Optional[float]=None - num_ctx: Optional[int]=None - num_gqa: Optional[int]=None - num_thread: Optional[int]=None - repeat_last_n: Optional[int]=None - repeat_penalty: Optional[float]=None - temperature: Optional[float]=None - stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 - tfs_z: Optional[float]=None - num_predict: Optional[int]=None - top_k: Optional[int]=None - top_p: Optional[float]=None - system: Optional[str]=None - template: Optional[str]=None - def __init__(self, - mirostat: Optional[int]=None, - mirostat_eta: Optional[float]=None, - mirostat_tau: Optional[float]=None, - num_ctx: Optional[int]=None, - num_gqa: Optional[int]=None, - num_thread: Optional[int]=None, - repeat_last_n: Optional[int]=None, - repeat_penalty: Optional[float]=None, - temperature: Optional[float]=None, - stop: Optional[list]=None, - tfs_z: Optional[float]=None, - num_predict: Optional[int]=None, - top_k: Optional[int]=None, - top_p: Optional[float]=None, - system: Optional[str]=None, - template: Optional[str]=None) -> None: + mirostat: Optional[int] = None + mirostat_eta: Optional[float] = None + mirostat_tau: Optional[float] = None + num_ctx: Optional[int] = None + num_gqa: Optional[int] = None + num_thread: Optional[int] = None + repeat_last_n: Optional[int] = None + repeat_penalty: Optional[float] = None + temperature: Optional[float] = None + stop: Optional[ + list + ] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 + tfs_z: Optional[float] = None + num_predict: Optional[int] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + system: Optional[str] = None + template: Optional[str] = None + + def __init__( + self, + mirostat: Optional[int] = None, + mirostat_eta: Optional[float] = None, + mirostat_tau: Optional[float] = None, + num_ctx: Optional[int] = None, + num_gqa: Optional[int] = None, + num_thread: Optional[int] = None, + repeat_last_n: Optional[int] = None, + repeat_penalty: Optional[float] = None, + temperature: Optional[float] = None, + stop: Optional[list] = None, + tfs_z: Optional[float] = None, + num_predict: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + system: Optional[str] = None, + template: Optional[str] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not 
k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + # ollama implementation def get_ollama_response( - api_base="http://localhost:11434", - model="llama2", - prompt="Why is the sky blue?", - optional_params=None, - logging_obj=None, - acompletion: bool = False, - model_response=None, - encoding=None - ): + api_base="http://localhost:11434", + model="llama2", + prompt="Why is the sky blue?", + optional_params=None, + logging_obj=None, + acompletion: bool = False, + model_response=None, + encoding=None, +): if api_base.endswith("/api/generate"): url = api_base - else: + else: url = f"{api_base}/api/generate" - + ## Load Config - config=litellm.OllamaConfig.get_config() + config = litellm.OllamaConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v optional_params["stream"] = optional_params.get("stream", False) - data = { - "model": model, - "prompt": prompt, - **optional_params - } + data = {"model": model, "prompt": prompt, **optional_params} ## LOGGING logging_obj.pre_call( input=None, api_key=None, - additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,}, + additional_args={ + "api_base": url, + "complete_input_dict": data, + "headers": {}, + "acompletion": acompletion, + }, ) - if acompletion is True: + if acompletion is True: if optional_params.get("stream", False) == True: - response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) + response = ollama_async_streaming( + url=url, + data=data, + model_response=model_response, + encoding=encoding, + logging_obj=logging_obj, + ) else: - response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) + response = ollama_acompletion( + url=url, + data=data, + model_response=model_response, + encoding=encoding, + logging_obj=logging_obj, + ) return response elif optional_params.get("stream", False) == True: return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) - + response = requests.post( - url=f"{url}", - json=data, - ) + url=f"{url}", + json=data, + ) if response.status_code != 200: - raise OllamaError(status_code=response.status_code, message=response.text) - + raise OllamaError(status_code=response.status_code, message=response.text) + ## LOGGING logging_obj.post_call( input=prompt, @@ -168,52 +202,76 @@ def get_ollama_response( ## RESPONSE OBJECT model_response["choices"][0]["finish_reason"] = "stop" if optional_params.get("format", "") == "json": - message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) + message = litellm.Message( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": {"arguments": response_json["response"], "name": ""}, + "type": "function", + } + ], + ) 
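# A minimal sketch of how a caller can recover the structured output produced
# above when format="json" was requested: the raw text Ollama generated is
# surfaced as the "arguments" string of a synthetic tool call rather than as
# message.content, presumably to keep an OpenAI-style response shape.
# json.loads below assumes the model actually emitted valid JSON; response_json
# is the variable already in scope in this function.
import json

raw_json = response_json["response"]  # same string placed in tool_calls[0]["function"]["arguments"]
parsed = json.loads(raw_json)         # caller-side parse of the JSON payload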
model_response["choices"][0]["message"] = message else: model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["created"] = int(time.time()) model_response["model"] = "ollama/" + model - prompt_tokens = response_json["prompt_eval_count"] # type: ignore + prompt_tokens = response_json["prompt_eval_count"] # type: ignore completion_tokens = response_json["eval_count"] - model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) + model_response["usage"] = litellm.Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) return model_response + def ollama_completion_stream(url, data, logging_obj): with httpx.stream( - url=url, - json=data, - method="POST", - timeout=litellm.request_timeout - ) as response: - try: + url=url, json=data, method="POST", timeout=litellm.request_timeout + ) as response: + try: if response.status_code != 200: - raise OllamaError(status_code=response.status_code, message=response.text) - - streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) + raise OllamaError( + status_code=response.status_code, message=response.text + ) + + streamwrapper = litellm.CustomStreamWrapper( + completion_stream=response.iter_lines(), + model=data["model"], + custom_llm_provider="ollama", + logging_obj=logging_obj, + ) for transformed_chunk in streamwrapper: yield transformed_chunk - except Exception as e: + except Exception as e: raise e + async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): try: client = httpx.AsyncClient() async with client.stream( - url=f"{url}", - json=data, - method="POST", - timeout=litellm.request_timeout - ) as response: - if response.status_code != 200: - raise OllamaError(status_code=response.status_code, message=response.text) - - streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) - async for transformed_chunk in streamwrapper: - yield transformed_chunk + url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout + ) as response: + if response.status_code != 200: + raise OllamaError( + status_code=response.status_code, message=response.text + ) + + streamwrapper = litellm.CustomStreamWrapper( + completion_stream=response.aiter_lines(), + model=data["model"], + custom_llm_provider="ollama", + logging_obj=logging_obj, + ) + async for transformed_chunk in streamwrapper: + yield transformed_chunk except Exception as e: traceback.print_exc() + async def ollama_acompletion(url, data, model_response, encoding, logging_obj): data["stream"] = False try: @@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj): if resp.status != 200: text = await resp.text() raise OllamaError(status_code=resp.status, message=text) - + ## LOGGING logging_obj.post_call( - input=data['prompt'], + input=data["prompt"], api_key="", original_response=resp.text, additional_args={ @@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj): ## RESPONSE OBJECT model_response["choices"][0]["finish_reason"] = "stop" if data.get("format", "") == "json": - message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": 
{"arguments": response_json["response"], "name": ""}, "type": "function"}]) + message = litellm.Message( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "arguments": response_json["response"], + "name": "", + }, + "type": "function", + } + ], + ) model_response["choices"][0]["message"] = message else: - model_response["choices"][0]["message"]["content"] = response_json["response"] + model_response["choices"][0]["message"]["content"] = response_json[ + "response" + ] model_response["created"] = int(time.time()) - model_response["model"] = "ollama/" + data['model'] - prompt_tokens = response_json["prompt_eval_count"] # type: ignore + model_response["model"] = "ollama/" + data["model"] + prompt_tokens = response_json["prompt_eval_count"] # type: ignore completion_tokens = response_json["eval_count"] - model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) + model_response["usage"] = litellm.Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) return model_response except Exception as e: traceback.print_exc() raise e -async def ollama_aembeddings(api_base="http://localhost:11434", - model="llama2", - prompt="Why is the sky blue?", - optional_params=None, - logging_obj=None, - model_response=None, - encoding=None): +async def ollama_aembeddings( + api_base="http://localhost:11434", + model="llama2", + prompt="Why is the sky blue?", + optional_params=None, + logging_obj=None, + model_response=None, + encoding=None, +): if api_base.endswith("/api/embeddings"): url = api_base - else: + else: url = f"{api_base}/api/embeddings" - + ## Load Config - config=litellm.OllamaConfig.get_config() + config = litellm.OllamaConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v data = { @@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434", if response.status != 200: text = await response.text() raise OllamaError(status_code=response.status, message=text) - + ## LOGGING logging_obj.post_call( input=prompt, @@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434", output_data = [] for idx, embedding in enumerate(embeddings): output_data.append( - { - "object": "embedding", - "index": idx, - "embedding": embedding - } + {"object": "embedding", "index": idx, "embedding": embedding} ) model_response["object"] = "list" model_response["data"] = output_data model_response["model"] = model - input_tokens = len(encoding.encode(prompt)) + input_tokens = len(encoding.encode(prompt)) - model_response["usage"] = { - "prompt_tokens": input_tokens, + model_response["usage"] = { + "prompt_tokens": input_tokens, "total_tokens": input_tokens, } - return model_response \ No newline at end of file + return model_response diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga.py index db0403c96..47d88cc79 100644 --- a/litellm/llms/oobabooga.py +++ b/litellm/llms/oobabooga.py @@ -7,6 +7,7 @@ from typing import Callable, Optional from litellm.utils import ModelResponse, Usage from .prompt_templates.factory import prompt_factory, custom_prompt + class OobaboogaError(Exception): def 
__init__(self, status_code, message): self.status_code = status_code @@ -15,6 +16,7 @@ class OobaboogaError(Exception): self.message ) # Call the base class constructor with the parameters it needs + def validate_environment(api_key): headers = { "accept": "application/json", @@ -24,6 +26,7 @@ def validate_environment(api_key): headers["Authorization"] = f"Token {api_key}" return headers + def completion( model: str, messages: list, @@ -44,21 +47,24 @@ def completion( completion_url = model elif api_base: completion_url = api_base - else: - raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')") + else: + raise OobaboogaError( + status_code=404, + message="API Base not set. Set one via completion(..,api_base='your-api-url')", + ) model = model if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: prompt = prompt_factory(model=model, messages=messages) - + completion_url = completion_url + "/api/v1/generate" data = { "prompt": prompt, @@ -66,30 +72,35 @@ def completion( } ## LOGGING logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( - completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, ) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() else: ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT try: completion_response = response.json() except: - raise OobaboogaError(message=response.text, status_code=response.status_code) + raise OobaboogaError( + message=response.text, status_code=response.status_code + ) if "error" in completion_response: raise OobaboogaError( message=completion_response["error"], @@ -97,14 +108,17 @@ def completion( ) else: try: - model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text'] + model_response["choices"][0]["message"][ + "content" + ] = completion_response["results"][0]["text"] except: - raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code) + raise OobaboogaError( + message=json.dumps(completion_response), + status_code=response.status_code, + ) ## CALCULATING USAGE - prompt_tokens = len( - encoding.encode(prompt) - ) + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"]["content"]) ) @@ -114,11 +128,12 @@ 
def completion( usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index e6b535295..5ef495631 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -2,15 +2,29 @@ from typing import Optional, Union, Any import types, time, json import httpx from .base import BaseLLM -from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage +from litellm.utils import ( + ModelResponse, + Choices, + Message, + CustomStreamWrapper, + convert_to_model_response_object, + Usage, +) from typing import Callable, Optional import aiohttp, requests import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI + class OpenAIError(Exception): - def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): + def __init__( + self, + status_code, + message, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + ): self.status_code = status_code self.message = message if request: @@ -20,13 +34,15 @@ class OpenAIError(Exception): if response: self.response = response else: - self.response = httpx.Response(status_code=status_code, request=self.request) + self.response = httpx.Response( + status_code=status_code, request=self.request + ) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class OpenAIConfig(): +class OpenAIConfig: """ Reference: https://platform.openai.com/docs/api-reference/chat/create @@ -50,44 +66,58 @@ class OpenAIConfig(): - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. - - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. 
""" - frequency_penalty: Optional[int]=None - function_call: Optional[Union[str, dict]]=None - functions: Optional[list]=None - logit_bias: Optional[dict]=None - max_tokens: Optional[int]=None - n: Optional[int]=None - presence_penalty: Optional[int]=None - stop: Optional[Union[str, list]]=None - temperature: Optional[int]=None - top_p: Optional[int]=None - def __init__(self, - frequency_penalty: Optional[int]=None, - function_call: Optional[Union[str, dict]]=None, - functions: Optional[list]=None, - logit_bias: Optional[dict]=None, - max_tokens: Optional[int]=None, - n: Optional[int]=None, - presence_penalty: Optional[int]=None, - stop: Optional[Union[str, list]]=None, - temperature: Optional[int]=None, - top_p: Optional[int]=None,) -> None: - + frequency_penalty: Optional[int] = None + function_call: Optional[Union[str, dict]] = None + functions: Optional[list] = None + logit_bias: Optional[dict] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[int] = None + stop: Optional[Union[str, list]] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + + def __init__( + self, + frequency_penalty: Optional[int] = None, + function_call: Optional[Union[str, dict]] = None, + functions: Optional[list] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } -class OpenAITextCompletionConfig(): + +class OpenAITextCompletionConfig: """ Reference: https://platform.openai.com/docs/api-reference/completions/create @@ -100,7 +130,7 @@ class OpenAITextCompletionConfig(): - `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line. - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. - + - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens. - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion. @@ -108,134 +138,199 @@ class OpenAITextCompletionConfig(): - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt. - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics. - + - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. - + - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text. 
- `temperature` (number or null): This optional parameter defines the sampling temperature to use. - + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. """ - best_of: Optional[int]=None - echo: Optional[bool]=None - frequency_penalty: Optional[int]=None - logit_bias: Optional[dict]=None - logprobs: Optional[int]=None - max_tokens: Optional[int]=None - n: Optional[int]=None - presence_penalty: Optional[int]=None - stop: Optional[Union[str, list]]=None - suffix: Optional[str]=None - temperature: Optional[float]=None - top_p: Optional[float]=None - def __init__(self, - best_of: Optional[int]=None, - echo: Optional[bool]=None, - frequency_penalty: Optional[int]=None, - logit_bias: Optional[dict]=None, - logprobs: Optional[int]=None, - max_tokens: Optional[int]=None, - n: Optional[int]=None, - presence_penalty: Optional[int]=None, - stop: Optional[Union[str, list]]=None, - suffix: Optional[str]=None, - temperature: Optional[float]=None, - top_p: Optional[float]=None) -> None: + best_of: Optional[int] = None + echo: Optional[bool] = None + frequency_penalty: Optional[int] = None + logit_bias: Optional[dict] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[int] = None + stop: Optional[Union[str, list]] = None + suffix: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + + def __init__( + self, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[int] = None, + logit_bias: Optional[dict] = None, + logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + suffix: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + class OpenAIChatCompletion(BaseLLM): - def __init__(self) -> None: super().__init__() - def completion(self, - model_response: ModelResponse, - timeout: float, - model: Optional[str]=None, - messages: Optional[list]=None, - print_verbose: Optional[Callable]=None, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - acompletion: bool = False, - logging_obj=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers: Optional[dict]=None, - custom_prompt_dict: dict={}, - client=None - ): + def completion( + self, + model_response: ModelResponse, + timeout: float, + model: Optional[str] = None, + messages: Optional[list] = None, + print_verbose: Optional[Callable] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + acompletion: bool = False, + logging_obj=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + custom_prompt_dict: dict = {}, + client=None, + ): 
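# A standalone sketch of the role-interleaving fallback applied further down in
# this method when the upstream API rejects two consecutive messages from the
# same role ("Conversation roles must alternate user/assistant"): a blank
# message of the opposite role is inserted between them. The three-message
# conversation below is made up purely for illustration.
def interleave_roles(messages: list) -> list:
    fixed = []
    for i in range(len(messages) - 1):
        fixed.append(messages[i])
        if messages[i]["role"] == messages[i + 1]["role"]:
            filler = "assistant" if messages[i]["role"] == "user" else "user"
            fixed.append({"role": filler, "content": ""})
    fixed.append(messages[-1])
    return fixed

interleave_roles(
    [
        {"role": "user", "content": "hi"},
        {"role": "user", "content": "still there?"},  # two consecutive user turns
        {"role": "assistant", "content": "yes"},
    ]
)
# -> a blank assistant message is inserted between the two user turns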
super().completion() exception_mapping_worked = False - try: - if headers: + try: + if headers: optional_params["extra_headers"] = headers if model is None or messages is None: raise OpenAIError(status_code=422, message=f"Missing model or messages") - - if not isinstance(timeout, float): - raise OpenAIError(status_code=422, message=f"Timeout needs to be a float") - for _ in range(2): # if call fails due to alternating messages, retry with reformatted message - data = { - "model": model, - "messages": messages, - **optional_params - } - - try: + if not isinstance(timeout, float): + raise OpenAIError( + status_code=422, message=f"Timeout needs to be a float" + ) + + for _ in range( + 2 + ): # if call fails due to alternating messages, retry with reformatted message + data = {"model": model, "messages": messages, **optional_params} + + try: max_retries = data.pop("max_retries", 2) - if acompletion is True: + if acompletion is True: if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) + return self.async_streaming( + logging_obj=logging_obj, + headers=headers, + data=data, + model=model, + api_base=api_base, + api_key=api_key, + timeout=timeout, + client=client, + max_retries=max_retries, + ) else: - return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) + return self.acompletion( + data=data, + headers=headers, + logging_obj=logging_obj, + model_response=model_response, + api_base=api_base, + api_key=api_key, + timeout=timeout, + client=client, + max_retries=max_retries, + ) elif optional_params.get("stream", False): - return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) + return self.streaming( + logging_obj=logging_obj, + headers=headers, + data=data, + model=model, + api_base=api_base, + api_key=api_key, + timeout=timeout, + client=client, + max_retries=max_retries, + ) else: ## LOGGING logging_obj.pre_call( input=messages, api_key=api_key, - additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data}, + additional_args={ + "headers": headers, + "api_base": api_base, + "acompletion": acompletion, + "complete_input_dict": data, + }, ) - if not isinstance(max_retries, int): - raise OpenAIError(status_code=422, message="max retries must be an int") + if not isinstance(max_retries, int): + raise OpenAIError( + status_code=422, message="max retries must be an int" + ) if client is None: - openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_client = client - response = openai_client.chat.completions.create(**data) # type: ignore + response = openai_client.chat.completions.create(**data) # type: ignore stringified_response = response.model_dump_json() logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=stringified_response, - additional_args={"complete_input_dict": data}, - ) - return 
convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) + input=messages, + api_key=api_key, + original_response=stringified_response, + additional_args={"complete_input_dict": data}, + ) + return convert_to_model_response_object( + response_object=json.loads(stringified_response), + model_response_object=model_response, + ) except Exception as e: - if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e): + if "Conversation roles must alternate user/assistant" in str( + e + ) or "user and assistant roles should be alternating" in str(e): # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility new_messages = [] - for i in range(len(messages)-1): + for i in range(len(messages) - 1): new_messages.append(messages[i]) - if messages[i]["role"] == messages[i+1]["role"]: + if messages[i]["role"] == messages[i + 1]["role"]: if messages[i]["role"] == "user": - new_messages.append({"role": "assistant", "content": ""}) + new_messages.append( + {"role": "assistant", "content": ""} + ) else: new_messages.append({"role": "user", "content": ""}) new_messages.append(messages[-1]) @@ -246,136 +341,197 @@ class OpenAIChatCompletion(BaseLLM): messages = new_messages else: raise e - except OpenAIError as e: + except OpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: + except Exception as e: raise e - - async def acompletion(self, - data: dict, - model_response: ModelResponse, - timeout: float, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - client=None, - max_retries=None, - logging_obj=None, - headers=None - ): + + async def acompletion( + self, + data: dict, + model_response: ModelResponse, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + logging_obj=None, + headers=None, + ): response = None - try: + try: if client is None: - openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_aclient = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=openai_aclient.api_key, - additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, + "api_base": openai_aclient._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, ) response = await openai_aclient.chat.completions.create(**data) stringified_response = response.model_dump_json() logging_obj.post_call( - input=data['messages'], - api_key=api_key, - original_response=stringified_response, - additional_args={"complete_input_dict": data}, - ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) + input=data["messages"], + api_key=api_key, + original_response=stringified_response, + additional_args={"complete_input_dict": data}, + ) + return 
convert_to_model_response_object( + response_object=json.loads(stringified_response), + model_response_object=model_response, + ) except Exception as e: raise e - def streaming(self, - logging_obj, - timeout: float, - data: dict, - model: str, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - client = None, - max_retries=None, - headers=None + def streaming( + self, + logging_obj, + timeout: float, + data: dict, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + headers=None, ): if client is None: - openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_client = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=api_key, - additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data}, + additional_args={ + "headers": headers, + "api_base": api_base, + "acompletion": False, + "complete_input_dict": data, + }, ) response = openai_client.chat.completions.create(**data) - streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + ) return streamwrapper - async def async_streaming(self, - logging_obj, - timeout: float, - data: dict, - model: str, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - client=None, - max_retries=None, - headers=None - ): + async def async_streaming( + self, + logging_obj, + timeout: float, + data: dict, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + headers=None, + ): response = None - try: + try: if client is None: - openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_aclient = client ## LOGGING logging_obj.pre_call( - input=data['messages'], + input=data["messages"], api_key=api_key, - additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": headers, + "api_base": api_base, + "acompletion": True, + "complete_input_dict": data, + }, ) response = await openai_aclient.chat.completions.create(**data) - streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + ) async for transformed_chunk in streamwrapper: yield transformed_chunk - except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions. + except ( + Exception + ) as e: # need to exception handle here. async exceptions don't get caught in sync functions. 
if response is not None and hasattr(response, "text"): - raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}") + raise OpenAIError( + status_code=500, + message=f"{str(e)}\n\nOriginal Response: {response.text}", + ) else: - if type(e).__name__ == "ReadTimeout": + if type(e).__name__ == "ReadTimeout": raise OpenAIError(status_code=408, message=f"{type(e).__name__}") else: raise OpenAIError(status_code=500, message=f"{str(e)}") + async def aembedding( - self, - input: list, - data: dict, - model_response: ModelResponse, - timeout: float, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - client=None, - max_retries=None, - logging_obj=None - ): + self, + input: list, + data: dict, + model_response: ModelResponse, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + logging_obj=None, + ): response = None - try: + try: if client is None: - openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_aclient = client - response = await openai_aclient.embeddings.create(**data) # type: ignore + response = await openai_aclient.embeddings.create(**data) # type: ignore stringified_response = response.model_dump_json() ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=stringified_response, - ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore except Exception as e: ## LOGGING logging_obj.post_call( @@ -384,95 +540,105 @@ class OpenAIChatCompletion(BaseLLM): original_response=str(e), ) raise e - - def embedding(self, - model: str, - input: list, - timeout: float, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - model_response: Optional[litellm.utils.EmbeddingResponse] = None, - logging_obj=None, - optional_params=None, - client=None, - aembedding=None, - ): + + def embedding( + self, + model: str, + input: list, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + model_response: Optional[litellm.utils.EmbeddingResponse] = None, + logging_obj=None, + optional_params=None, + client=None, + aembedding=None, + ): super().embedding() exception_mapping_worked = False - try: + try: model = model - data = { - "model": model, - "input": input, - **optional_params - } + data = {"model": model, "input": input, **optional_params} max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): + if not isinstance(max_retries, int): raise OpenAIError(status_code=422, message="max retries must be an int") ## LOGGING logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data, "api_base": api_base}, - ) - + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data, "api_base": api_base}, + ) + if aembedding == 
True: - response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore + response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore return response if client is None: - openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_client = client - + ## COMPLETION CALL - response = openai_client.embeddings.create(**data) # type: ignore + response = openai_client.embeddings.create(**data) # type: ignore ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response, - ) - - return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore - except OpenAIError as e: + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) + + return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore + except OpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: - if exception_mapping_worked: + except Exception as e: + if exception_mapping_worked: raise e - else: + else: import traceback + raise OpenAIError(status_code=500, message=traceback.format_exc()) async def aimage_generation( - self, - prompt: str, - data: dict, - model_response: ModelResponse, - timeout: float, - api_key: Optional[str]=None, - api_base: Optional[str]=None, - client=None, - max_retries=None, - logging_obj=None - ): + self, + prompt: str, + data: dict, + model_response: ModelResponse, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + logging_obj=None, + ): response = None - try: + try: if client is None: - openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_aclient = client - response = await openai_aclient.images.generate(**data) # type: ignore + response = await openai_aclient.images.generate(**data) # type: ignore stringified_response = response.model_dump_json() ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=stringified_response, - ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, 
response_type="embedding") # type: ignore except Exception as e: ## LOGGING logging_obj.post_call( @@ -482,74 +648,84 @@ class OpenAIChatCompletion(BaseLLM): ) raise e - def image_generation(self, - model: Optional[str], - prompt: str, - timeout: float, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - model_response: Optional[litellm.utils.ImageResponse] = None, - logging_obj=None, - optional_params=None, - client=None, - aimg_generation=None, - ): + def image_generation( + self, + model: Optional[str], + prompt: str, + timeout: float, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + model_response: Optional[litellm.utils.ImageResponse] = None, + logging_obj=None, + optional_params=None, + client=None, + aimg_generation=None, + ): exception_mapping_worked = False - try: + try: model = model - data = { - "model": model, - "prompt": prompt, - **optional_params - } + data = {"model": model, "prompt": prompt, **optional_params} max_retries = data.pop("max_retries", 2) - if not isinstance(max_retries, int): + if not isinstance(max_retries, int): raise OpenAIError(status_code=422, message="max retries must be an int") - + # if aembedding == True: # response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore # return response - + if client is None: - openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + ) else: openai_client = client - + ## LOGGING logging_obj.pre_call( input=prompt, api_key=openai_client.api_key, - additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, + additional_args={ + "headers": {"Authorization": f"Bearer {openai_client.api_key}"}, + "api_base": openai_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, ) - + ## COMPLETION CALL - response = openai_client.images.generate(**data) # type: ignore + response = openai_client.images.generate(**data) # type: ignore ## LOGGING logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response, - ) + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) # return response - return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore - except OpenAIError as e: + return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore + except OpenAIError as e: exception_mapping_worked = True raise e - except Exception as e: - if exception_mapping_worked: + except Exception as e: + if exception_mapping_worked: raise e - else: + else: import traceback + raise OpenAIError(status_code=500, message=traceback.format_exc()) + class OpenAITextCompletion(BaseLLM): _client_session: httpx.Client def __init__(self) -> None: super().__init__() self._client_session = self.create_client_session() - + def 
validate_environment(self, api_key): headers = { "content-type": "application/json", @@ -557,82 +733,110 @@ class OpenAITextCompletion(BaseLLM): if api_key: headers["Authorization"] = f"Bearer {api_key}" return headers - - def convert_to_model_response_object(self, response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None): - try: + + def convert_to_model_response_object( + self, + response_object: Optional[dict] = None, + model_response_object: Optional[ModelResponse] = None, + ): + try: ## RESPONSE OBJECT if response_object is None or model_response_object is None: raise ValueError("Error in response object format") - choice_list=[] - for idx, choice in enumerate(response_object["choices"]): + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): message = Message(content=choice["text"], role="assistant") - choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message) + choice = Choices( + finish_reason=choice["finish_reason"], index=idx, message=message + ) choice_list.append(choice) model_response_object.choices = choice_list - if "usage" in response_object: + if "usage" in response_object: model_response_object.usage = response_object["usage"] - - if "id" in response_object: + + if "id" in response_object: model_response_object.id = response_object["id"] - - if "model" in response_object: + + if "model" in response_object: model_response_object.model = response_object["model"] - - model_response_object._hidden_params["original_response"] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response + + model_response_object._hidden_params[ + "original_response" + ] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response return model_response_object - except Exception as e: + except Exception as e: raise e - def completion(self, - model_response: ModelResponse, - api_key: str, - model: str, - messages: list, - print_verbose: Optional[Callable]=None, - api_base: Optional[str]=None, - logging_obj=None, - acompletion: bool = False, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers: Optional[dict]=None): + def completion( + self, + model_response: ModelResponse, + api_key: str, + model: str, + messages: list, + print_verbose: Optional[Callable] = None, + api_base: Optional[str] = None, + logging_obj=None, + acompletion: bool = False, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + ): super().completion() exception_mapping_worked = False - try: + try: if headers is None: headers = self.validate_environment(api_key=api_key) if model is None or messages is None: raise OpenAIError(status_code=422, message=f"Missing model or messages") - + api_base = f"{api_base}/completions" - if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list: + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): prompt = messages[0]["content"] else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore + prompt = " ".join([message["content"] for message in messages]) # type: ignore # don't send max retries to the api, if set optional_params.pop("max_retries", None) - data = { - "model": model, - "prompt": prompt, - **optional_params - } + data = {"model": model, "prompt": prompt, **optional_params} ## LOGGING 
logging_obj.pre_call( input=messages, api_key=api_key, - additional_args={"headers": headers, "api_base": api_base, "complete_input_dict": data}, + additional_args={ + "headers": headers, + "api_base": api_base, + "complete_input_dict": data, + }, ) - if acompletion == True: + if acompletion == True: if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) + return self.async_streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + headers=headers, + model_response=model_response, + model=model, + ) else: - return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore + return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore elif optional_params.get("stream", False): - return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) + return self.streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + headers=headers, + model_response=model_response, + model=model, + ) else: response = httpx.post( url=f"{api_base}", @@ -640,8 +844,10 @@ class OpenAITextCompletion(BaseLLM): headers=headers, ) if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - + raise OpenAIError( + status_code=response.status_code, message=response.text + ) + ## LOGGING logging_obj.post_call( input=prompt, @@ -654,26 +860,38 @@ class OpenAITextCompletion(BaseLLM): ) ## RESPONSE OBJECT - return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response) - except Exception as e: + return self.convert_to_model_response_object( + response_object=response.json(), + model_response_object=model_response, + ) + except Exception as e: raise e - - async def acompletion(self, - logging_obj, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - prompt: str, - api_key: str, - model: str): + + async def acompletion( + self, + logging_obj, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + prompt: str, + api_key: str, + model: str, + ): async with httpx.AsyncClient() as client: - try: - response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) + try: + response = await client.post( + api_base, + json=data, + headers=headers, + timeout=litellm.request_timeout, + ) response_json = response.json() if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - + raise OpenAIError( + status_code=response.status_code, message=response.text + ) + ## LOGGING logging_obj.post_call( input=prompt, @@ -686,53 +904,72 @@ class OpenAITextCompletion(BaseLLM): ) ## RESPONSE OBJECT - return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) - except Exception as e: + return self.convert_to_model_response_object( + response_object=response_json, model_response_object=model_response + ) + except Exception as e: raise e - def streaming(self, - logging_obj, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - model: str + def 
streaming( + self, + logging_obj, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + model: str, ): with httpx.stream( - url=f"{api_base}", - json=data, - headers=headers, - method="POST", - timeout=litellm.request_timeout - ) as response: - if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - - streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) - for transformed_chunk in streamwrapper: - yield transformed_chunk + url=f"{api_base}", + json=data, + headers=headers, + method="POST", + timeout=litellm.request_timeout, + ) as response: + if response.status_code != 200: + raise OpenAIError( + status_code=response.status_code, message=response.text + ) - async def async_streaming(self, - logging_obj, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - model: str): + streamwrapper = CustomStreamWrapper( + completion_stream=response.iter_lines(), + model=model, + custom_llm_provider="text-completion-openai", + logging_obj=logging_obj, + ) + for transformed_chunk in streamwrapper: + yield transformed_chunk + + async def async_streaming( + self, + logging_obj, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + model: str, + ): client = httpx.AsyncClient() async with client.stream( - url=f"{api_base}", - json=data, - headers=headers, - method="POST", - timeout=litellm.request_timeout - ) as response: - try: + url=f"{api_base}", + json=data, + headers=headers, + method="POST", + timeout=litellm.request_timeout, + ) as response: + try: if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - - streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) + raise OpenAIError( + status_code=response.status_code, message=response.text + ) + + streamwrapper = CustomStreamWrapper( + completion_stream=response.aiter_lines(), + model=model, + custom_llm_provider="text-completion-openai", + logging_obj=logging_obj, + ) async for transformed_chunk in streamwrapper: yield transformed_chunk - except Exception as e: - raise e \ No newline at end of file + except Exception as e: + raise e diff --git a/litellm/llms/openrouter.py b/litellm/llms/openrouter.py index fa21ab61e..b6ec4024f 100644 --- a/litellm/llms/openrouter.py +++ b/litellm/llms/openrouter.py @@ -1,30 +1,41 @@ from typing import List, Dict import types -class OpenrouterConfig(): + +class OpenrouterConfig: """ Reference: https://openrouter.ai/docs#format """ + # OpenRouter-only parameters - extra_body: Dict[str, List[str]] = { - 'transforms': [] # default transforms to [] - } + extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to [] - - def __init__(self, - transforms: List[str] = [], - models: List[str] = [], - route: str = '', - ) -> None: + def __init__( + self, + transforms: List[str] = [], + models: List[str] = [], + route: str = "", + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, 
staticmethod)) - and v is not None} \ No newline at end of file + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } diff --git a/litellm/llms/palm.py b/litellm/llms/palm.py index 010e6720c..060e6dca1 100644 --- a/litellm/llms/palm.py +++ b/litellm/llms/palm.py @@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage import litellm import sys, httpx + class PalmError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat") + self.request = httpx.Request( + method="POST", + url="https://developers.generativeai.google/api/python/google/generativeai/chat", + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class PalmConfig(): + +class PalmConfig: """ Reference: https://developers.generativeai.google/api/python/google/generativeai/chat @@ -37,35 +42,47 @@ class PalmConfig(): - `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output """ - context: Optional[str]=None - examples: Optional[list]=None - temperature: Optional[float]=None - candidate_count: Optional[int]=None - top_k: Optional[int]=None - top_p: Optional[float]=None - max_output_tokens: Optional[int]=None - def __init__(self, - context: Optional[str]=None, - examples: Optional[list]=None, - temperature: Optional[float]=None, - candidate_count: Optional[int]=None, - top_k: Optional[int]=None, - top_p: Optional[float]=None, - max_output_tokens: Optional[int]=None) -> None: - + context: Optional[str] = None + examples: Optional[list] = None + temperature: Optional[float] = None + candidate_count: Optional[int] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + max_output_tokens: Optional[int] = None + + def __init__( + self, + context: Optional[str] = None, + examples: Optional[list] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + max_output_tokens: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def completion( @@ -83,41 +100,43 @@ def completion( try: import google.generativeai as palm except: - raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") + raise Exception( + "Importing google.generativeai failed, please run 'pip install -q google-generativeai" + ) palm.configure(api_key=api_key) model = model - + ## Load Config inference_params = copy.deepcopy(optional_params) - inference_params.pop("stream", 
None) # palm does not support streaming, so we handle this by fake streaming in main.py - config = litellm.PalmConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params.pop( + "stream", None + ) # palm does not support streaming, so we handle this by fake streaming in main.py + config = litellm.PalmConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v prompt = "" for message in messages: if "role" in message: if message["role"] == "user": - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: prompt += f"{message['content']}" - + ## LOGGING logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": {"inference_params": inference_params}}, - ) + input=prompt, + api_key="", + additional_args={"complete_input_dict": {"inference_params": inference_params}}, + ) ## COMPLETION CALL - try: + try: response = palm.generate_text(prompt=prompt, **inference_params) except Exception as e: raise PalmError( @@ -127,11 +146,11 @@ def completion( ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=response, - additional_args={"complete_input_dict": {}}, - ) + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": {}}, + ) print_verbose(f"raw model_response: {response}") ## RESPONSE OBJECT completion_response = response @@ -142,22 +161,25 @@ def completion( message_obj = Message(content=item["output"]) else: message_obj = Message(content=None) - choice_obj = Choices(index=idx+1, message=message_obj) + choice_obj = Choices(index=idx + 1, message=message_obj) choices_list.append(choice_obj) model_response["choices"] = choices_list except Exception as e: traceback.print_exc() - raise PalmError(message=traceback.format_exc(), status_code=response.status_code) - - try: + raise PalmError( + message=traceback.format_exc(), status_code=response.status_code + ) + + try: completion_response = model_response["choices"][0]["message"].get("content") except: - raise PalmError(status_code=400, message=f"No response received. Original response - {response}") + raise PalmError( + status_code=400, + message=f"No response received. Original response - {response}", + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
+ prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -165,13 +187,14 @@ def completion( model_response["created"] = int(time.time()) model_response["model"] = "palm/" + model usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/petals.py b/litellm/llms/petals.py index f9ce3ad0c..bc30306a6 100644 --- a/litellm/llms/petals.py +++ b/litellm/llms/petals.py @@ -8,6 +8,7 @@ import litellm from litellm.utils import ModelResponse, Usage from .prompt_templates.factory import prompt_factory, custom_prompt + class PetalsError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -16,7 +17,8 @@ class PetalsError(Exception): self.message ) # Call the base class constructor with the parameters it needs -class PetalsConfig(): + +class PetalsConfig: """ Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below: @@ -30,45 +32,64 @@ class PetalsConfig(): - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below: - `temperature` (float, optional): This value sets the temperature for sampling. - + - `top_k` (integer, optional): This value sets the limit for top-k sampling. - + - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling. - + - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper. 
""" - max_length: Optional[int]=None - max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set - do_sample: Optional[bool]=None - temperature: Optional[float]=None - top_k: Optional[int]=None - top_p: Optional[float]=None - repetition_penalty: Optional[float]=None - def __init__(self, - max_length: Optional[int]=None, - max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set - do_sample: Optional[bool]=None, - temperature: Optional[float]=None, - top_k: Optional[int]=None, - top_p: Optional[float]=None, - repetition_penalty: Optional[float]=None) -> None: + max_length: Optional[int] = None + max_new_tokens: Optional[ + int + ] = litellm.max_tokens # petals requires max tokens to be set + do_sample: Optional[bool] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + + def __init__( + self, + max_length: Optional[int] = None, + max_new_tokens: Optional[ + int + ] = litellm.max_tokens, # petals requires max tokens to be set + do_sample: Optional[bool] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def completion( model: str, messages: list, - api_base: Optional[str], + api_base: Optional[str], model_response: ModelResponse, print_verbose: Callable, encoding, @@ -80,96 +101,97 @@ def completion( ): ## Load Config config = litellm.PetalsConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v if model in litellm.custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = litellm.custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: prompt = prompt_factory(model=model, messages=messages) - if api_base: + if api_base: ## LOGGING logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": optional_params, "api_base": api_base}, - ) - data = { - "model": model, - "inputs": prompt, - **optional_params - } - + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": optional_params, + "api_base": api_base, + }, + 
) + data = {"model": model, "inputs": prompt, **optional_params} + ## COMPLETION CALL response = requests.post(api_base, data=data) - + ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=response.text, - additional_args={"complete_input_dict": optional_params}, - ) - + input=prompt, + api_key="", + original_response=response.text, + additional_args={"complete_input_dict": optional_params}, + ) + ## RESPONSE OBJECT try: output_text = response.json()["outputs"] except Exception as e: PetalsError(status_code=response.status_code, message=str(e)) - else: + else: try: import torch from transformers import AutoTokenizer - from petals import AutoDistributedModelForCausalLM # type: ignore + from petals import AutoDistributedModelForCausalLM # type: ignore except: raise Exception( "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals" ) - + model = model - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False) + tokenizer = AutoTokenizer.from_pretrained( + model, use_fast=False, add_bos_token=False + ) model_obj = AutoDistributedModelForCausalLM.from_pretrained(model) ## LOGGING logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": optional_params}, - ) - + input=prompt, + api_key="", + additional_args={"complete_input_dict": optional_params}, + ) + ## COMPLETION CALL inputs = tokenizer(prompt, return_tensors="pt")["input_ids"] - + # optional params: max_new_tokens=1,temperature=0.9, top_p=0.6 outputs = model_obj.generate(inputs, **optional_params) ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=outputs, - additional_args={"complete_input_dict": optional_params}, - ) + input=prompt, + api_key="", + original_response=outputs, + additional_args={"complete_input_dict": optional_params}, + ) ## RESPONSE OBJECT output_text = tokenizer.decode(outputs[0]) - + if len(output_text) > 0: model_response["choices"][0]["message"]["content"] = output_text - prompt_tokens = len( - encoding.encode(prompt) - ) + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content")) ) @@ -177,13 +199,14 @@ def completion( model_response["created"] = int(time.time()) model_response["model"] = model usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 265cae941..57d30a404 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -4,11 +4,13 @@ import json from jinja2 import Template, exceptions, Environment, meta from typing import Optional, Any + def default_pt(messages): return " ".join(message["content"] for message in messages) -# alpaca prompt template - for models like mythomax, etc. -def alpaca_pt(messages): + +# alpaca prompt template - for models like mythomax, etc. 
+def alpaca_pt(messages): prompt = custom_prompt( role_dict={ "system": { @@ -19,59 +21,56 @@ def alpaca_pt(messages): "pre_message": "### Instruction:\n", "post_message": "\n\n", }, - "assistant": { - "pre_message": "### Response:\n", - "post_message": "\n\n" - } + "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"}, }, bos_token="", eos_token="", - messages=messages + messages=messages, ) return prompt + # Llama2 prompt template def llama_2_chat_pt(messages): prompt = custom_prompt( role_dict={ "system": { "pre_message": "[INST] <>\n", - "post_message": "\n<>\n [/INST]\n" + "post_message": "\n<>\n [/INST]\n", }, - "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348 + "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348 "pre_message": "[INST] ", - "post_message": " [/INST]\n" - }, + "post_message": " [/INST]\n", + }, "assistant": { - "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama - } + "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama + }, }, messages=messages, bos_token="", - eos_token="" + eos_token="", ) return prompt -def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template - - if "instruct" in model: + +def ollama_pt( + model, messages +): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template + if "instruct" in model: prompt = custom_prompt( role_dict={ - "system": { - "pre_message": "### System:\n", - "post_message": "\n" - }, + "system": {"pre_message": "### System:\n", "post_message": "\n"}, "user": { "pre_message": "### User:\n", "post_message": "\n", - }, + }, "assistant": { "pre_message": "### Response:\n", "post_message": "\n", - } + }, }, final_prompt_value="### Response:", - messages=messages + messages=messages, ) elif "llava" in model: prompt = "" @@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf elif element["type"] == "image_url": image_url = element["image_url"]["url"] images.append(image_url) - return { - "prompt": prompt, - "images": images - } - else: - prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages) + return {"prompt": prompt, "images": images} + else: + prompt = "".join( + m["content"] + if isinstance(m["content"], str) is str + else "".join(m["content"]) + for m in messages + ) return prompt -def mistral_instruct_pt(messages): + +def mistral_instruct_pt(messages): prompt = custom_prompt( initial_prompt_value="", role_dict={ - "system": { - "pre_message": "[INST]", - "post_message": "[/INST]" - }, - "user": { - "pre_message": "[INST]", - "post_message": "[/INST]" - }, - "assistant": { - "pre_message": "[INST]", - "post_message": "[/INST]" - } + "system": {"pre_message": "[INST]", "post_message": "[/INST]"}, + "user": {"pre_message": "[INST]", "post_message": "[/INST]"}, + "assistant": {"pre_message": "[INST]", "post_message": "[/INST]"}, }, final_prompt_value="", - messages=messages + messages=messages, ) return prompt + # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 def falcon_instruct_pt(messages): prompt = "" @@ -125,11 +119,16 @@ def falcon_instruct_pt(messages): if 
message["role"] == "system": prompt += message["content"] else: - prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n") + prompt += ( + message["role"] + + ":" + + message["content"].replace("\r\n", "\n").replace("\n\n", "\n") + ) prompt += "\n\n" - + return prompt + def falcon_chat_pt(messages): prompt = "" for message in messages: @@ -142,6 +141,7 @@ def falcon_chat_pt(messages): return prompt + # MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 def mpt_chat_pt(messages): prompt = "" @@ -154,18 +154,20 @@ def mpt_chat_pt(messages): prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n" return prompt + # WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format def wizardcoder_pt(messages): prompt = "" for message in messages: if message["role"] == "system": prompt += message["content"] + "\n\n" - elif message["role"] == "user": # map to 'Instruction' + elif message["role"] == "user": # map to 'Instruction' prompt += "### Instruction:\n" + message["content"] + "\n\n" - elif message["role"] == "assistant": # map to 'Response' + elif message["role"] == "assistant": # map to 'Response' prompt += "### Response:\n" + message["content"] + "\n\n" return prompt - + + # Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model def phind_codellama_pt(messages): prompt = "" @@ -178,13 +180,17 @@ def phind_codellama_pt(messages): prompt += "### Assistant\n" + message["content"] + "\n\n" return prompt -def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None): + +def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None): ## get the tokenizer config from huggingface bos_token = "" eos_token = "" - if chat_template is None: + if chat_template is None: + def _get_tokenizer_config(hf_model_name): - url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" + url = ( + f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" + ) # Make a GET request to fetch the JSON data response = requests.get(url) if response.status_code == 200: @@ -193,10 +199,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No return {"status": "success", "tokenizer": tokenizer_config} else: return {"status": "failure"} + tokenizer_config = _get_tokenizer_config(model) - if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]: + if ( + tokenizer_config["status"] == "failure" + or "chat_template" not in tokenizer_config["tokenizer"] + ): raise Exception("No chat template found") - ## read the bos token, eos token and chat template from the json + ## read the bos token, eos token and chat template from the json tokenizer_config = tokenizer_config["tokenizer"] bos_token = tokenizer_config["bos_token"] eos_token = tokenizer_config["eos_token"] @@ -204,10 +214,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No def raise_exception(message): raise Exception(f"Error message - {message}") - + # Create a template object from the template text env = Environment() - env.globals['raise_exception'] = raise_exception + env.globals["raise_exception"] = raise_exception try: template = env.from_string(chat_template) except Exception as e: @@ -216,137 +226,167 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No def 
_is_system_in_template(): try: # Try rendering the template with a system message - response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "", bos_token= "") + response = template.render( + messages=[{"role": "system", "content": "test"}], + eos_token="", + bos_token="", + ) return True # This will be raised if Jinja attempts to render the system message and it can't except: return False - - try: + + try: # Render the template with the provided values - if _is_system_in_template(): - rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages) - else: + if _is_system_in_template(): + rendered_text = template.render( + bos_token=bos_token, eos_token=eos_token, messages=messages + ) + else: # treat a system message as a user message, if system not in template try: reformatted_messages = [] - for message in messages: - if message["role"] == "system": - reformatted_messages.append({"role": "user", "content": message["content"]}) + for message in messages: + if message["role"] == "system": + reformatted_messages.append( + {"role": "user", "content": message["content"]} + ) else: reformatted_messages.append(message) - rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages) + rendered_text = template.render( + bos_token=bos_token, + eos_token=eos_token, + messages=reformatted_messages, + ) except Exception as e: - if "Conversation roles must alternate user/assistant" in str(e): + if "Conversation roles must alternate user/assistant" in str(e): # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility new_messages = [] - for i in range(len(reformatted_messages)-1): + for i in range(len(reformatted_messages) - 1): new_messages.append(reformatted_messages[i]) - if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]: + if ( + reformatted_messages[i]["role"] + == reformatted_messages[i + 1]["role"] + ): if reformatted_messages[i]["role"] == "user": - new_messages.append({"role": "assistant", "content": ""}) + new_messages.append( + {"role": "assistant", "content": ""} + ) else: new_messages.append({"role": "user", "content": ""}) new_messages.append(reformatted_messages[-1]) - rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages) + rendered_text = template.render( + bos_token=bos_token, eos_token=eos_token, messages=new_messages + ) return rendered_text - except Exception as e: + except Exception as e: raise Exception(f"Error rendering template - {str(e)}") -# Anthropic template -def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts + +# Anthropic template +def claude_2_1_pt( + messages: list, +): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts """ - Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human: + Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human: - you can't just pass a system message - - you can't pass a system message and follow that with an assistant message + - you can't pass a system message and follow that with an assistant message if system message is passed in, you can only do system, human, assistant or system, human - if a system message is passed in and followed by an assistant message, insert a 
blank human message between them. + if a system message is passed in and followed by an assistant message, insert a blank human message between them. """ + class AnthropicConstants(Enum): HUMAN_PROMPT = "\n\nHuman: " AI_PROMPT = "\n\nAssistant: " - - prompt = "" - for idx, message in enumerate(messages): + + prompt = "" + for idx, message in enumerate(messages): if message["role"] == "user": - prompt += ( - f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" - ) + prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" elif message["role"] == "system": - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" elif message["role"] == "assistant": - if idx > 0 and messages[idx - 1]["role"] == "system": - prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message - prompt += ( - f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" - ) - prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn + if idx > 0 and messages[idx - 1]["role"] == "system": + prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message + prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" + prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn return prompt -### TOGETHER AI + +### TOGETHER AI + def get_model_info(token, model): - try: - headers = { - 'Authorization': f'Bearer {token}' - } - response = requests.get('https://api.together.xyz/models/info', headers=headers) + try: + headers = {"Authorization": f"Bearer {token}"} + response = requests.get("https://api.together.xyz/models/info", headers=headers) if response.status_code == 200: model_info = response.json() - for m in model_info: - if m["name"].lower().strip() == model.strip(): - return m['config'].get('prompt_format', None), m['config'].get('chat_template', None) + for m in model_info: + if m["name"].lower().strip() == model.strip(): + return m["config"].get("prompt_format", None), m["config"].get( + "chat_template", None + ) return None, None else: return None, None - except Exception as e: # safely fail a prompt template request + except Exception as e: # safely fail a prompt template request return None, None + def format_prompt_togetherai(messages, prompt_format, chat_template): if prompt_format is None: return default_pt(messages) - - human_prompt, assistant_prompt = prompt_format.split('{prompt}') + + human_prompt, assistant_prompt = prompt_format.split("{prompt}") if chat_template is not None: - prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template) - elif prompt_format is not None: - prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt) - else: + prompt = hf_chat_template( + model=None, messages=messages, chat_template=chat_template + ) + elif prompt_format is not None: + prompt = custom_prompt( + role_dict={}, + messages=messages, + initial_prompt_value=human_prompt, + final_prompt_value=assistant_prompt, + ) + else: prompt = default_pt(messages) - return prompt + return prompt + ### -def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post + +def anthropic_pt( + messages: list, +): # format - https://docs.anthropic.com/claude/reference/complete_post class AnthropicConstants(Enum): HUMAN_PROMPT = "\n\nHuman: " AI_PROMPT = "\n\nAssistant: " - - prompt = "" - for idx, message in enumerate(messages): # 
needs to start with `\n\nHuman: ` and end with `\n\nAssistant: ` + + prompt = "" + for idx, message in enumerate( + messages + ): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: ` if message["role"] == "user": - prompt += ( - f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" - ) + prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" elif message["role"] == "system": - prompt += ( - f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" - ) + prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" else: - prompt += ( - f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" - ) - if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: ` + prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" + if ( + idx == 0 and message["role"] == "assistant" + ): # ensure the prompt always starts with `\n\nHuman: ` prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt prompt += f"{AnthropicConstants.AI_PROMPT.value}" - return prompt + return prompt -def gemini_text_image_pt(messages: list): + +def gemini_text_image_pt(messages: list): """ { "contents":[ @@ -367,13 +407,15 @@ def gemini_text_image_pt(messages: list): try: import google.generativeai as genai except: - raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") - + raise Exception( + "Importing google.generativeai failed, please run 'pip install -q google-generativeai" + ) + prompt = "" - images = [] - for message in messages: + images = [] + for message in messages: if isinstance(message["content"], str): - prompt += message["content"] + prompt += message["content"] elif isinstance(message["content"], list): # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models for element in message["content"]: @@ -383,45 +425,63 @@ def gemini_text_image_pt(messages: list): elif element["type"] == "image_url": image_url = element["image_url"]["url"] images.append(image_url) - + content = [prompt] + images return content -# Function call template + +# Function call template def function_call_prompt(messages: list, functions: list): - function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:" - for function in functions: + function_prompt = ( + "Produce JSON OUTPUT ONLY! 
The following functions are available to you:" + ) + for function in functions: function_prompt += f"""\n{function}\n""" - + function_added_to_prompt = False - for message in messages: - if "system" in message["role"]: - message['content'] += f"""{function_prompt}""" + for message in messages: + if "system" in message["role"]: + message["content"] += f"""{function_prompt}""" function_added_to_prompt = True - - if function_added_to_prompt == False: - messages.append({'role': 'system', 'content': f"""{function_prompt}"""}) + + if function_added_to_prompt == False: + messages.append({"role": "system", "content": f"""{function_prompt}"""}) return messages # Custom prompt template -def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""): +def custom_prompt( + role_dict: dict, + messages: list, + initial_prompt_value: str = "", + final_prompt_value: str = "", + bos_token: str = "", + eos_token: str = "", +): prompt = bos_token + initial_prompt_value bos_open = True ## a bos token is at the start of a system / human message ## an eos token is at the end of the assistant response to the message for message in messages: role = message["role"] - + if role in ["system", "human"] and not bos_open: prompt += bos_token bos_open = True - - pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else "" - post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else "" + + pre_message_str = ( + role_dict[role]["pre_message"] + if role in role_dict and "pre_message" in role_dict[role] + else "" + ) + post_message_str = ( + role_dict[role]["post_message"] + if role in role_dict and "post_message" in role_dict[role] + else "" + ) prompt += pre_message_str + message["content"] + post_message_str - + if role == "assistant": prompt += eos_token bos_open = False @@ -429,25 +489,35 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", prompt += final_prompt_value return prompt -def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None): + +def prompt_factory( + model: str, + messages: list, + custom_llm_provider: Optional[str] = None, + api_key: Optional[str] = None, +): original_model_name = model model = model.lower() - if custom_llm_provider == "ollama": + if custom_llm_provider == "ollama": return ollama_pt(model=model, messages=messages) elif custom_llm_provider == "anthropic": - if "claude-2.1" in model: + if "claude-2.1" in model: return claude_2_1_pt(messages=messages) - else: + else: return anthropic_pt(messages=messages) - elif custom_llm_provider == "together_ai": + elif custom_llm_provider == "together_ai": prompt_format, chat_template = get_model_info(token=api_key, model=model) - return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template) - elif custom_llm_provider == "gemini": + return format_prompt_togetherai( + messages=messages, prompt_format=prompt_format, chat_template=chat_template + ) + elif custom_llm_provider == "gemini": return gemini_text_image_pt(messages=messages) try: if "meta-llama/llama-2" in model and "chat" in model: return llama_2_chat_pt(messages=messages) - elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template. 
+ elif ( + "tiiuae/falcon" in model + ): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template. if model == "tiiuae/falcon-180B-chat": return falcon_chat_pt(messages=messages) elif "instruct" in model: @@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str return mpt_chat_pt(messages=messages) elif "codellama/codellama" in model or "togethercomputer/codellama" in model: if "instruct" in model: - return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions + return llama_2_chat_pt( + messages=messages + ) # https://huggingface.co/blog/codellama#conversational-instructions elif "wizardlm/wizardcoder" in model: return wizardcoder_pt(messages=messages) elif "phind/phind-codellama" in model: return phind_codellama_pt(messages=messages) - elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model): + elif "togethercomputer/llama-2" in model and ( + "instruct" in model or "chat" in model + ): return llama_2_chat_pt(messages=messages) - elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]: - return alpaca_pt(messages=messages) - else: + elif model in [ + "gryphe/mythomax-l2-13b", + "gryphe/mythomix-l2-13b", + "gryphe/mythologic-l2-13b", + ]: + return alpaca_pt(messages=messages) + else: return hf_chat_template(original_model_name, messages) except Exception as e: - return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) - + return default_pt( + messages=messages + ) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py index b952193e6..bc296bfd0 100644 --- a/litellm/llms/replicate.py +++ b/litellm/llms/replicate.py @@ -4,81 +4,100 @@ import requests import time from typing import Callable, Optional from litellm.utils import ModelResponse, Usage -import litellm +import litellm import httpx from .prompt_templates.factory import prompt_factory, custom_prompt + class ReplicateError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments") + self.request = httpx.Request( + method="POST", url="https://api.replicate.com/v1/deployments" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class ReplicateConfig(): + +class ReplicateConfig: """ Reference: https://replicate.com/meta/llama-2-70b-chat/api - `prompt` (string): The prompt to send to the model. - + - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`. - + - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`. - + - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`. - + - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`. 
- + - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`. - + - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`. - + - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting ',' will cease generation at the first occurrence of either 'end' or ''. - + - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed. - + - `debug` (boolean): If set to `True`, it provides debugging output in logs. Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models. """ - system_prompt: Optional[str]=None - max_new_tokens: Optional[int]=None - min_new_tokens: Optional[int]=None - temperature: Optional[int]=None - top_p: Optional[int]=None - top_k: Optional[int]=None - stop_sequences: Optional[str]=None - seed: Optional[int]=None - debug: Optional[bool]=None - def __init__(self, - system_prompt: Optional[str]=None, - max_new_tokens: Optional[int]=None, - min_new_tokens: Optional[int]=None, - temperature: Optional[int]=None, - top_p: Optional[int]=None, - top_k: Optional[int]=None, - stop_sequences: Optional[str]=None, - seed: Optional[int]=None, - debug: Optional[bool]=None) -> None: + system_prompt: Optional[str] = None + max_new_tokens: Optional[int] = None + min_new_tokens: Optional[int] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + stop_sequences: Optional[str] = None + seed: Optional[int] = None + debug: Optional[bool] = None + + def __init__( + self, + system_prompt: Optional[str] = None, + max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + stop_sequences: Optional[str] = None, + seed: Optional[int] = None, + debug: Optional[bool] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } # Function to start a prediction and get the prediction URL -def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose): +def start_prediction( + version_id, input_data, api_token, api_base, logging_obj, print_verbose +): base_url = api_base if "deployments" in version_id: print_verbose("\nLiteLLM: Request to custom replicate deployment") @@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p headers = { "Authorization": f"Token {api_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } initial_prediction_data = { @@ -98,24 +117,33 @@ def start_prediction(version_id, input_data, api_token, api_base, 
logging_obj, p ## LOGGING logging_obj.pre_call( - input=input_data["prompt"], - api_key="", - additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url}, + input=input_data["prompt"], + api_key="", + additional_args={ + "complete_input_dict": initial_prediction_data, + "headers": headers, + "api_base": base_url, + }, ) - response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers) + response = requests.post( + f"{base_url}/predictions", json=initial_prediction_data, headers=headers + ) if response.status_code == 201: response_data = response.json() return response_data.get("urls", {}).get("get") else: - raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}") + raise ReplicateError( + response.status_code, f"Failed to start prediction {response.text}" + ) + # Function to handle prediction response (non-streaming) def handle_prediction_response(prediction_url, api_token, print_verbose): output_string = "" headers = { "Authorization": f"Token {api_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } status = "" @@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose): if response.status_code == 200: response_data = response.json() if "output" in response_data: - output_string = "".join(response_data['output']) + output_string = "".join(response_data["output"]) print_verbose(f"Non-streamed output:{output_string}") - status = response_data.get('status', None) + status = response_data.get("status", None) logs = response_data.get("logs", "") if status == "failed": replicate_error = response_data.get("error", "") - raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}") + raise ReplicateError( + status_code=400, + message=f"Error: {replicate_error}, \nReplicate logs:{logs}", + ) else: # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed" print_verbose("Replicate: Failed to fetch prediction status and output.") return output_string, logs + # Function to handle prediction response (streaming) def handle_prediction_response_streaming(prediction_url, api_token, print_verbose): previous_output = "" @@ -146,30 +178,34 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos headers = { "Authorization": f"Token {api_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } status = "" while True and (status not in ["succeeded", "failed", "canceled"]): - time.sleep(0.5) # prevent being rate limited by replicate + time.sleep(0.5) # prevent being rate limited by replicate print_verbose(f"replicate: polling endpoint: {prediction_url}") response = requests.get(prediction_url, headers=headers) if response.status_code == 200: response_data = response.json() - status = response_data['status'] + status = response_data["status"] if "output" in response_data: - output_string = "".join(response_data['output']) - new_output = output_string[len(previous_output):] + output_string = "".join(response_data["output"]) + new_output = output_string[len(previous_output) :] print_verbose(f"New chunk: {new_output}") yield {"output": new_output, "status": status} previous_output = output_string - status = response_data['status'] + status = response_data["status"] if status == "failed": replicate_error = response_data.get("error", "") - raise ReplicateError(status_code=400, 
message=f"Error: {replicate_error}") + raise ReplicateError( + status_code=400, message=f"Error: {replicate_error}" + ) else: # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed" - print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}") - + print_verbose( + f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}" + ) + # Function to extract version ID from model string def model_to_version_id(model): @@ -178,11 +214,12 @@ def model_to_version_id(model): return split_model[1] return model + # Main function for prediction completion def completion( model: str, messages: list, - api_base: str, + api_base: str, model_response: ModelResponse, print_verbose: Callable, logging_obj, @@ -196,35 +233,37 @@ def completion( # Start a prediction and get the prediction URL version_id = model_to_version_id(model) ## Load Config - config = litellm.ReplicateConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.ReplicateConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v - + system_prompt = None if optional_params is not None and "supports_system_prompt" in optional_params: supports_sys_prompt = optional_params.pop("supports_system_prompt") else: supports_sys_prompt = False - + if supports_sys_prompt: for i in range(len(messages)): if messages[i]["role"] == "system": first_sys_message = messages.pop(i) system_prompt = first_sys_message["content"] break - + if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", {}), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - bos_token=model_prompt_details.get("bos_token", ""), - eos_token=model_prompt_details.get("eos_token", ""), - messages=messages, - ) + role_dict=model_prompt_details.get("roles", {}), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + bos_token=model_prompt_details.get("bos_token", ""), + eos_token=model_prompt_details.get("eos_token", ""), + messages=messages, + ) else: prompt = prompt_factory(model=model, messages=messages) @@ -233,43 +272,58 @@ def completion( input_data = { "prompt": prompt, "system_prompt": system_prompt, - **optional_params + **optional_params, } # Otherwise, use the prompt as is else: - input_data = { - "prompt": prompt, - **optional_params - } - + input_data = {"prompt": prompt, **optional_params} ## COMPLETION CALL ## Replicate Compeltion calls have 2 steps ## Step1: Start Prediction: gets a prediction url ## Step2: Poll prediction url for response ## Step2: is handled with and without streaming - model_response["created"] = int(time.time()) # for pricing this must remain right before calling api - prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose) + model_response["created"] = int( + time.time() + ) # for pricing this must remain right 
before calling api + prediction_url = start_prediction( + version_id, + input_data, + api_key, + api_base, + logging_obj=logging_obj, + print_verbose=print_verbose, + ) print_verbose(prediction_url) # Handle the prediction response (streaming or non-streaming) if "stream" in optional_params and optional_params["stream"] == True: print_verbose("streaming request") - return handle_prediction_response_streaming(prediction_url, api_key, print_verbose) + return handle_prediction_response_streaming( + prediction_url, api_key, print_verbose + ) else: - result, logs = handle_prediction_response(prediction_url, api_key, print_verbose) - model_response["ended"] = time.time() # for pricing this must remain right after calling api + result, logs = handle_prediction_response( + prediction_url, api_key, print_verbose + ) + model_response[ + "ended" + ] = time.time() # for pricing this must remain right after calling api ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=result, - additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, }, + input=prompt, + api_key="", + original_response=result, + additional_args={ + "complete_input_dict": input_data, + "logs": logs, + "api_base": prediction_url, + }, ) print_verbose(f"raw model_response: {result}") - if len(result) == 0: # edge case, where result from replicate is empty + if len(result) == 0: # edge case, where result from replicate is empty result = " " ## Building RESPONSE OBJECT @@ -278,12 +332,14 @@ def completion( # Calculate usage prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", ""))) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) model_response["model"] = "replicate/" + model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 2bfa9f82a..00af132e8 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -11,42 +11,61 @@ from copy import deepcopy import httpx from .prompt_templates.factory import prompt_factory, custom_prompt + class SagemakerError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker") + self.request = httpx.Request( + method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class SagemakerConfig(): + +class SagemakerConfig: """ Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb """ - max_new_tokens: Optional[int]=None - top_p: Optional[float]=None - temperature: Optional[float]=None - return_full_text: Optional[bool]=None - def __init__(self, - max_new_tokens: Optional[int]=None, - top_p: Optional[float]=None, - temperature: Optional[float]=None, - return_full_text: Optional[bool]=None) -> None: + max_new_tokens: Optional[int] = None + top_p: Optional[float] = None + temperature: 
Optional[float] = None + return_full_text: Optional[bool] = None + + def __init__( + self, + max_new_tokens: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + return_full_text: Optional[bool] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} - + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + """ SAGEMAKER AUTH Keys/Vars os.environ['AWS_ACCESS_KEY_ID'] = "" @@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = "" # set os.environ['AWS_REGION_NAME'] = + def completion( model: str, messages: list, @@ -85,28 +105,30 @@ def completion( region_name=aws_region_name, ) else: - # aws_access_key_id is None, assume user is trying to auth using env variables + # aws_access_key_id is None, assume user is trying to auth using env variables # boto3 automaticaly reads env variables # we need to read region name from env - # I assume majority of users use .env for auth + # I assume majority of users use .env for auth region_name = ( - get_secret("AWS_REGION_NAME") or - "us-west-2" # default to us-west-2 if user not specified + get_secret("AWS_REGION_NAME") + or "us-west-2" # default to us-west-2 if user not specified ) client = boto3.client( service_name="sagemaker-runtime", region_name=region_name, ) - + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker inference_params = deepcopy(optional_params) inference_params.pop("stream", None) ## Load Config - config = litellm.SagemakerConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.SagemakerConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v model = model @@ -114,25 +136,26 @@ def completion( # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, ) else: if hf_model_name is None: - if "llama-2" in model.lower(): # llama-2 model - if "chat" in model.lower(): # apply llama2 chat template + if "llama-2" in model.lower(): # llama-2 model + if "chat" in model.lower(): # apply llama2 chat template hf_model_name = "meta-llama/Llama-2-7b-chat-hf" - else: # apply regular llama2 template + else: # apply regular llama2 template hf_model_name = "meta-llama/Llama-2-7b" - hf_model_name = hf_model_name or model # pass in hf model 
name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) + hf_model_name = ( + hf_model_name or model + ) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) prompt = prompt_factory(model=hf_model_name, messages=messages) - data = json.dumps({ - "inputs": prompt, - "parameters": inference_params - }).encode('utf-8') + data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( + "utf-8" + ) ## LOGGING request_str = f""" @@ -142,31 +165,35 @@ def completion( Body={data}, CustomAttributes="accept_eula=true", ) - """ # type: ignore + """ # type: ignore logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name}, - ) + input=prompt, + api_key="", + additional_args={ + "complete_input_dict": data, + "request_str": request_str, + "hf_model_name": hf_model_name, + }, + ) ## COMPLETION CALL - try: + try: response = client.invoke_endpoint( EndpointName=model, ContentType="application/json", Body=data, CustomAttributes="accept_eula=true", ) - except Exception as e: + except Exception as e: raise SagemakerError(status_code=500, message=f"{str(e)}") - + response = response["Body"].read().decode("utf8") ## LOGGING logging_obj.post_call( - input=prompt, - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response}") ## RESPONSE OBJECT completion_response = json.loads(response) @@ -177,19 +204,20 @@ def completion( completion_output += completion_response_choices["generation"] elif "generated_text" in completion_response_choices: completion_output += completion_response_choices["generated_text"] - - # check if the prompt template is part of output, if so - filter it out + + # check if the prompt template is part of output, if so - filter it out if completion_output.startswith(prompt) and "" in prompt: completion_output = completion_output.replace(prompt, "", 1) model_response["choices"][0]["message"]["content"] = completion_output except: - raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500) + raise SagemakerError( + message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", + status_code=500, + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
+ prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) @@ -197,28 +225,32 @@ def completion( model_response["created"] = int(time.time()) model_response["model"] = model usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response -def embedding(model: str, - input: list, - model_response: EmbeddingResponse, - print_verbose: Callable, - encoding, - logging_obj, - custom_prompt_dict={}, - optional_params=None, - litellm_params=None, - logger_fn=None): + +def embedding( + model: str, + input: list, + model_response: EmbeddingResponse, + print_verbose: Callable, + encoding, + logging_obj, + custom_prompt_dict={}, + optional_params=None, + litellm_params=None, + logger_fn=None, +): """ - Supports Huggingface Jumpstart embeddings like GPT-6B + Supports Huggingface Jumpstart embeddings like GPT-6B """ ### BOTO3 INIT - import boto3 + import boto3 + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) @@ -234,34 +266,34 @@ def embedding(model: str, region_name=aws_region_name, ) else: - # aws_access_key_id is None, assume user is trying to auth using env variables + # aws_access_key_id is None, assume user is trying to auth using env variables # boto3 automaticaly reads env variables # we need to read region name from env - # I assume majority of users use .env for auth + # I assume majority of users use .env for auth region_name = ( - get_secret("AWS_REGION_NAME") or - "us-west-2" # default to us-west-2 if user not specified + get_secret("AWS_REGION_NAME") + or "us-west-2" # default to us-west-2 if user not specified ) client = boto3.client( service_name="sagemaker-runtime", region_name=region_name, ) - + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker inference_params = deepcopy(optional_params) inference_params.pop("stream", None) ## Load Config - config = litellm.SagemakerConfig.get_config() - for k, v in config.items(): - if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.SagemakerConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - #### HF EMBEDDING LOGIC - data = json.dumps({ - "text_inputs": input - }).encode('utf-8') + #### HF EMBEDDING LOGIC + data = json.dumps({"text_inputs": input}).encode("utf-8") ## LOGGING request_str = f""" @@ -270,67 +302,65 @@ def embedding(model: str, ContentType="application/json", Body={data}, CustomAttributes="accept_eula=true", - )""" # type: ignore + )""" # type: ignore logging_obj.pre_call( - input=input, - api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, - ) + input=input, + api_key="", + additional_args={"complete_input_dict": data, "request_str": request_str}, + ) ## EMBEDDING CALL - try: + try: response = client.invoke_endpoint( EndpointName=model, ContentType="application/json", 
Body=data, CustomAttributes="accept_eula=true", ) - except Exception as e: + except Exception as e: raise SagemakerError(status_code=500, message=f"{str(e)}") ## LOGGING logging_obj.post_call( - input=input, - api_key="", - additional_args={"complete_input_dict": data}, - original_response=response, - ) - + input=input, + api_key="", + additional_args={"complete_input_dict": data}, + original_response=response, + ) response = json.loads(response["Body"].read().decode("utf8")) ## LOGGING logging_obj.post_call( - input=input, - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) + input=input, + api_key="", + original_response=response, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response}") - if "embedding" not in response: + if "embedding" not in response: raise SagemakerError(status_code=500, message="embedding not found in response") - embeddings = response['embedding'] + embeddings = response["embedding"] if not isinstance(embeddings, list): - raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}") - + raise SagemakerError( + status_code=422, message=f"Response not in expected format - {embeddings}" + ) output_data = [] for idx, embedding in enumerate(embeddings): output_data.append( - { - "object": "embedding", - "index": idx, - "embedding": embedding - } + {"object": "embedding", "index": idx, "embedding": embedding} ) model_response["object"] = "list" model_response["data"] = output_data model_response["model"] = model - + input_tokens = 0 for text in input: - input_tokens+=len(encoding.encode(text)) + input_tokens += len(encoding.encode(text)) + + model_response["usage"] = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) - model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens) - return model_response diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py index 540dbe202..e8a74724b 100644 --- a/litellm/llms/together_ai.py +++ b/litellm/llms/together_ai.py @@ -9,17 +9,21 @@ import httpx from litellm.utils import ModelResponse, Usage from .prompt_templates.factory import prompt_factory, custom_prompt + class TogetherAIError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference") + self.request = httpx.Request( + method="POST", url="https://api.together.xyz/inference" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class TogetherAIConfig(): + +class TogetherAIConfig: """ Reference: https://docs.together.ai/reference/inference @@ -37,35 +41,49 @@ class TogetherAIConfig(): - `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition. - - `logprobs` (int32, optional): This parameter is not described in the prompt. + - `logprobs` (int32, optional): This parameter is not described in the prompt. 
""" - max_tokens: Optional[int]=None - stop: Optional[str]=None - temperature:Optional[int]=None - top_p: Optional[float]=None - top_k: Optional[int]=None - repetition_penalty: Optional[float]=None - logprobs: Optional[int]=None - - def __init__(self, - max_tokens: Optional[int]=None, - stop: Optional[str]=None, - temperature:Optional[int]=None, - top_p: Optional[float]=None, - top_k: Optional[int]=None, - repetition_penalty: Optional[float]=None, - logprobs: Optional[int]=None) -> None: + + max_tokens: Optional[int] = None + stop: Optional[str] = None + temperature: Optional[int] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + repetition_penalty: Optional[float] = None + logprobs: Optional[int] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + stop: Optional[str] = None, + temperature: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + repetition_penalty: Optional[float] = None, + logprobs: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } def validate_environment(api_key): @@ -80,10 +98,11 @@ def validate_environment(api_key): } return headers + def completion( model: str, messages: list, - api_base: str, + api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, @@ -97,9 +116,11 @@ def completion( headers = validate_environment(api_key) ## Load Config - config = litellm.TogetherAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in + config = litellm.TogetherAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}") @@ -107,15 +128,20 @@ def completion( # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", {}), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - bos_token=model_prompt_details.get("bos_token", ""), - eos_token=model_prompt_details.get("eos_token", ""), - messages=messages, - ) + role_dict=model_prompt_details.get("roles", {}), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + bos_token=model_prompt_details.get("bos_token", ""), + eos_token=model_prompt_details.get("eos_token", ""), + messages=messages, + ) else: - prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list + prompt = prompt_factory( 
+ model=model, + messages=messages, + api_key=api_key, + custom_llm_provider="together_ai", + ) # api key required to query together ai model list data = { "model": model, @@ -128,13 +154,14 @@ def completion( logging_obj.pre_call( input=prompt, api_key=api_key, - additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base}, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": api_base, + }, ) ## COMPLETION CALL - if ( - "stream_tokens" in optional_params - and optional_params["stream_tokens"] == True - ): + if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: response = requests.post( api_base, headers=headers, @@ -143,18 +170,14 @@ def completion( ) return response.iter_lines() else: - response = requests.post( - api_base, - headers=headers, - data=json.dumps(data) - ) + response = requests.post(api_base, headers=headers, data=json.dumps(data)) ## LOGGING logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT if response.status_code != 200: @@ -170,30 +193,38 @@ def completion( ) elif "error" in completion_response["output"]: raise TogetherAIError( - message=json.dumps(completion_response["output"]), status_code=response.status_code + message=json.dumps(completion_response["output"]), + status_code=response.status_code, ) - + if len(completion_response["output"]["choices"][0]["text"]) >= 0: - model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"] + model_response["choices"][0]["message"]["content"] = completion_response[ + "output" + ]["choices"][0]["text"] ## CALCULATING USAGE - print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}") + print_verbose( + f"CALCULATING TOGETHERAI TOKEN USAGE. 
Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}" + ) prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) if "finish_reason" in completion_response["output"]["choices"][0]: - model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"] + model_response.choices[0].finish_reason = completion_response["output"][ + "choices" + ][0]["finish_reason"] model_response["created"] = int(time.time()) model_response["model"] = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index e25a0b925..0c3b6ff8c 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper import litellm import httpx + class VertexAIError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message - self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/") + self.request = httpx.Request( + method="POST", url=" https://cloud.google.com/vertex-ai/" + ) self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs -class VertexAIConfig(): + +class VertexAIConfig: """ Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts @@ -34,28 +38,42 @@ class VertexAIConfig(): Note: Please make sure to modify the default parameters as required for your use case. 
""" - temperature: Optional[float]=None - max_output_tokens: Optional[int]=None - top_p: Optional[float]=None - top_k: Optional[int]=None - def __init__(self, - temperature: Optional[float]=None, - max_output_tokens: Optional[int]=None, - top_p: Optional[float]=None, - top_k: Optional[int]=None) -> None: - + temperature: Optional[float] = None + max_output_tokens: Optional[int] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + + def __init__( + self, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + ) -> None: locals_ = locals() for key, value in locals_.items(): - if key != 'self' and value is not None: + if key != "self" and value is not None: setattr(self.__class__, key, value) - + @classmethod def get_config(cls): - return {k: v for k, v in cls.__dict__.items() - if not k.startswith('__') - and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) - and v is not None} + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + def _get_image_bytes_from_url(image_url: str) -> bytes: try: @@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes: return image_bytes except requests.exceptions.RequestException as e: # Handle any request exceptions (e.g., connection error, timeout) - return b'' # Return an empty bytes object or handle the error as needed + return b"" # Return an empty bytes object or handle the error as needed def _load_image_from_url(image_url: str): @@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str): Returns: Image: The loaded image. """ - from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image + from vertexai.preview.generative_models import ( + GenerativeModel, + Part, + GenerationConfig, + Image, + ) + image_bytes = _get_image_bytes_from_url(image_url) return Image.from_bytes(image_bytes) -def _gemini_vision_convert_messages( - messages: list -): + +def _gemini_vision_convert_messages(messages: list): """ Converts given messages for GPT-4 Vision to Gemini format. @@ -95,7 +118,7 @@ def _gemini_vision_convert_messages( Returns: tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images). - + Raises: VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed. Exception: If any other exception occurs during the execution of the function. 
@@ -115,11 +138,23 @@ def _gemini_vision_convert_messages( try: import vertexai except: - raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") - try: - from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair + raise VertexAIError( + status_code=400, + message="vertexai import failed please run `pip install google-cloud-aiplatform`", + ) + try: + from vertexai.preview.language_models import ( + ChatModel, + CodeChatModel, + InputOutputTextPair, + ) from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image + from vertexai.preview.generative_models import ( + GenerativeModel, + Part, + GenerationConfig, + Image, + ) # given messages for gpt-4 vision, convert them for gemini # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb @@ -159,6 +194,7 @@ def _gemini_vision_convert_messages( except Exception as e: raise e + def completion( model: str, messages: list, @@ -171,30 +207,38 @@ def completion( optional_params=None, litellm_params=None, logger_fn=None, - acompletion: bool=False + acompletion: bool = False, ): try: import vertexai except: - raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") - try: - from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair + raise VertexAIError( + status_code=400, + message="vertexai import failed please run `pip install google-cloud-aiplatform`", + ) + try: + from vertexai.preview.language_models import ( + ChatModel, + CodeChatModel, + InputOutputTextPair, + ) from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig + from vertexai.preview.generative_models import ( + GenerativeModel, + Part, + GenerationConfig, + ) from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types - - vertexai.init( - project=vertex_project, location=vertex_location - ) + vertexai.init(project=vertex_project, location=vertex_location) ## Load Config config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: + for k, v in config.items(): + if k not in optional_params: optional_params[k] = v - ## Process safety settings into format expected by vertex AI + ## Process safety settings into format expected by vertex AI safety_settings = None if "safety_settings" in optional_params: safety_settings = optional_params.pop("safety_settings") @@ -202,17 +246,25 @@ def completion( raise ValueError("safety_settings must be a list") if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): raise ValueError("safety_settings must be a list of dicts") - safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings] + safety_settings = [ + gapic_content_types.SafetySetting(x) for x in safety_settings + ] # vertexai does not use an API key, it looks for credentials.json in the environment - prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)]) + prompt = " ".join( + [ + message["content"] + for message in messages + if isinstance(message["content"], str) + ] + ) - mode = "" + mode = "" request_str = "" response_obj = None - if model in litellm.vertex_language_models: + if 
model in litellm.vertex_language_models: llm_model = GenerativeModel(model) mode = "" request_str += f"llm_model = GenerativeModel({model})\n" @@ -232,31 +284,76 @@ def completion( llm_model = CodeGenerationModel.from_pretrained(model) mode = "text" request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - else: # vertex_code_llm_models + else: # vertex_code_llm_models llm_model = CodeChatModel.from_pretrained(model) mode = "chat" request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - - if acompletion == True: # [TODO] expand support to vertex ai chat + text models - if optional_params.get("stream", False) is True: + + if acompletion == True: # [TODO] expand support to vertex ai chat + text models + if optional_params.get("stream", False) is True: # async streaming - return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params) - return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params) + return async_streaming( + llm_model=llm_model, + mode=mode, + prompt=prompt, + logging_obj=logging_obj, + request_str=request_str, + model=model, + model_response=model_response, + messages=messages, + print_verbose=print_verbose, + **optional_params, + ) + return async_completion( + llm_model=llm_model, + mode=mode, + prompt=prompt, + logging_obj=logging_obj, + request_str=request_str, + model=model, + model_response=model_response, + encoding=encoding, + messages=messages, + print_verbose=print_verbose, + **optional_params, + ) if mode == "": - if "stream" in optional_params and optional_params["stream"] == True: stream = optional_params.pop("stream") request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + model_response = llm_model.generate_content( + prompt, + generation_config=GenerationConfig(**optional_params), + safety_settings=safety_settings, + stream=stream, + ) optional_params["stream"] = True return model_response request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response_obj = llm_model.generate_content( + prompt, + generation_config=GenerationConfig(**optional_params), + 
safety_settings=safety_settings, + ) completion_response = response_obj.text response_obj = response_obj._raw_response elif mode == "vision": @@ -268,21 +365,35 @@ def completion( if "stream" in optional_params and optional_params["stream"] == True: stream = optional_params.pop("stream") request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + model_response = llm_model.generate_content( contents=content, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, - stream=True + stream=True, ) optional_params["stream"] = True return model_response request_str += f"response = llm_model.generate_content({content})\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + ## LLM Call response = llm_model.generate_content( contents=content, @@ -293,88 +404,150 @@ def completion( response_obj = response._raw_response elif mode == "chat": chat = llm_model.start_chat() - request_str+= f"chat = llm_model.start_chat()\n" + request_str += f"chat = llm_model.start_chat()\n" if "stream" in optional_params and optional_params["stream"] == True: # NOTE: VertexAI does not accept stream=True as a param and raises an error, # we handle this by removing 'stream' from optional params and sending the request # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format - optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params - request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n" + optional_params.pop( + "stream", None + ) # vertex ai raises an error when passing stream in optional params + request_str += ( + f"chat.send_message_streaming({prompt}, **{optional_params})\n" + ) ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) model_response = chat.send_message_streaming(prompt, **optional_params) optional_params["stream"] = True return model_response request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) completion_response = chat.send_message(prompt, **optional_params).text elif mode == "text": if "stream" in optional_params and optional_params["stream"] == True: - optional_params.pop("stream", None) # See note above on handling streaming for vertex ai - request_str += 
f"llm_model.predict_streaming({prompt}, **{optional_params})\n" + optional_params.pop( + "stream", None + ) # See note above on handling streaming for vertex ai + request_str += ( + f"llm_model.predict_streaming({prompt}, **{optional_params})\n" + ) ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) model_response = llm_model.predict_streaming(prompt, **optional_params) optional_params["stream"] = True return model_response request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) completion_response = llm_model.predict(prompt, **optional_params).text - + ## LOGGING logging_obj.post_call( input=prompt, api_key=None, original_response=completion_response ) ## RESPONSE OBJECT - if len(str(completion_response)) > 0: - model_response["choices"][0]["message"][ - "content" - ] = str(completion_response) + if len(str(completion_response)) > 0: + model_response["choices"][0]["message"]["content"] = str( + completion_response + ) model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["created"] = int(time.time()) model_response["model"] = model ## CALCULATING USAGE if model in litellm.vertex_language_models and response_obj is not None: - model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name - usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, - completion_tokens=response_obj.usage_metadata.candidates_token_count, - total_tokens=response_obj.usage_metadata.total_token_count) - else: - prompt_tokens = len( - encoding.encode(prompt) - ) + model_response["choices"][0].finish_reason = response_obj.candidates[ + 0 + ].finish_reason.name + usage = Usage( + prompt_tokens=response_obj.usage_metadata.prompt_token_count, + completion_tokens=response_obj.usage_metadata.candidates_token_count, + total_tokens=response_obj.usage_metadata.total_token_count, + ) + else: + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) + encoding.encode( + model_response["choices"][0]["message"].get("content", "") + ) ) usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response - except Exception as e: + except Exception as e: raise VertexAIError(status_code=500, message=str(e)) -async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params): + +async def async_completion( + llm_model, + mode: str, + prompt: str, + model: str, + model_response: ModelResponse, + logging_obj=None, + request_str=None, + encoding=None, + messages=None, + print_verbose=None, + **optional_params, 
+): """ Add support for acompletion calls for gemini-pro """ - try: + try: from vertexai.preview.generative_models import GenerationConfig if mode == "": # gemini-pro chat = llm_model.start_chat() ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params)) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response_obj = await chat.send_message_async( + prompt, generation_config=GenerationConfig(**optional_params) + ) completion_response = response_obj.text response_obj = response_obj._raw_response elif mode == "vision": @@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_ request_str += f"response = llm_model.generate_content({content})\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + ## LLM Call response = await llm_model._generate_content_async( - contents=content, - generation_config=GenerationConfig(**optional_params) + contents=content, generation_config=GenerationConfig(**optional_params) ) completion_response = response.text response_obj = response._raw_response @@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_ # chat-bison etc. chat = llm_model.start_chat() ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) response_obj = await chat.send_message_async(prompt, **optional_params) completion_response = response_obj.text elif mode == "text": # gecko etc. 
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) response_obj = await llm_model.predict_async(prompt, **optional_params) completion_response = response_obj.text @@ -416,51 +609,77 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_ ) ## RESPONSE OBJECT - if len(str(completion_response)) > 0: - model_response["choices"][0]["message"][ - "content" - ] = str(completion_response) + if len(str(completion_response)) > 0: + model_response["choices"][0]["message"]["content"] = str( + completion_response + ) model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["created"] = int(time.time()) model_response["model"] = model ## CALCULATING USAGE if model in litellm.vertex_language_models and response_obj is not None: - model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name - usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, - completion_tokens=response_obj.usage_metadata.candidates_token_count, - total_tokens=response_obj.usage_metadata.total_token_count) + model_response["choices"][0].finish_reason = response_obj.candidates[ + 0 + ].finish_reason.name + usage = Usage( + prompt_tokens=response_obj.usage_metadata.prompt_token_count, + completion_tokens=response_obj.usage_metadata.candidates_token_count, + total_tokens=response_obj.usage_metadata.total_token_count, + ) else: - prompt_tokens = len( - encoding.encode(prompt) - ) + prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) + encoding.encode( + model_response["choices"][0]["message"].get("content", "") + ) ) usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens - ) + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) model_response.usage = usage return model_response - except Exception as e: + except Exception as e: raise VertexAIError(status_code=500, message=str(e)) -async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params): + +async def async_streaming( + llm_model, + mode: str, + prompt: str, + model: str, + model_response: ModelResponse, + logging_obj=None, + request_str=None, + messages=None, + print_verbose=None, + **optional_params, +): """ Add support for async streaming calls for gemini-pro """ from vertexai.preview.generative_models import GenerationConfig - if mode == "": + + if mode == "": # gemini-pro chat = llm_model.start_chat() stream = optional_params.pop("stream") request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream) + logging_obj.pre_call( + input=prompt, + 
api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = await chat.send_message_async( + prompt, generation_config=GenerationConfig(**optional_params), stream=stream + ) optional_params["stream"] = True - elif mode == "vision": + elif mode == "vision": stream = optional_params.pop("stream") print_verbose("\nMaking VertexAI Gemini Pro Vision Call") @@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r content = [prompt] + images stream = optional_params.pop("stream") request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = llm_model._generate_content_streaming_async( contents=content, generation_config=GenerationConfig(**optional_params), - stream=True + stream=True, ) optional_params["stream"] = True elif mode == "chat": chat = llm_model.start_chat() - optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params - request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n" + optional_params.pop( + "stream", None + ) # vertex ai raises an error when passing stream in optional params + request_str += ( + f"chat.send_message_streaming_async({prompt}, **{optional_params})\n" + ) ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) response = chat.send_message_streaming_async(prompt, **optional_params) optional_params["stream"] = True elif mode == "text": - optional_params.pop("stream", None) # See note above on handling streaming for vertex ai - request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n" + optional_params.pop( + "stream", None + ) # See note above on handling streaming for vertex ai + request_str += ( + f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n" + ) ## LOGGING - logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) response = llm_model.predict_streaming_async(prompt, **optional_params) - streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="vertex_ai", + logging_obj=logging_obj, + ) async for transformed_chunk in streamwrapper: yield transformed_chunk + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/llms/vllm.py b/litellm/llms/vllm.py index 428dc959d..2b130765b 100644 --- a/litellm/llms/vllm.py +++ b/litellm/llms/vllm.py @@ -6,7 +6,10 @@ import time, httpx from typing import Callable, Any from litellm.utils import 
ModelResponse, Usage from .prompt_templates.factory import prompt_factory, custom_prompt + llm = None + + class VLLMError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -17,17 +20,20 @@ class VLLMError(Exception): self.message ) # Call the base class constructor with the parameters it needs + # check if vllm is installed def validate_environment(model: str): global llm - try: - from vllm import LLM, SamplingParams # type: ignore + try: + from vllm import LLM, SamplingParams # type: ignore + if llm is None: llm = LLM(model=model) return llm, SamplingParams except Exception as e: raise VLLMError(status_code=0, message=str(e)) + def completion( model: str, messages: list, @@ -50,15 +56,14 @@ def completion( # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: prompt = prompt_factory(model=model, messages=messages) - ## LOGGING logging_obj.pre_call( input=prompt, @@ -69,9 +74,10 @@ def completion( if llm: outputs = llm.generate(prompt, sampling_params) else: - raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") + raise VLLMError( + status_code=0, message="Need to pass in a model name to initialize vllm" + ) - ## COMPLETION CALL if "stream" in optional_params and optional_params["stream"] == True: return iter(outputs) @@ -88,24 +94,22 @@ def completion( model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text ## CALCULATING USAGE - prompt_tokens = len(outputs[0].prompt_token_ids) - completion_tokens = len(outputs[0].outputs[0].token_ids) + prompt_tokens = len(outputs[0].prompt_token_ids) + completion_tokens = len(outputs[0].outputs[0].token_ids) model_response["created"] = int(time.time()) model_response["model"] = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage return model_response + def batch_completions( - model: str, - messages: list, - optional_params=None, - custom_prompt_dict={} + model: str, messages: list, optional_params=None, custom_prompt_dict={} ): """ Example usage: @@ -137,31 +141,33 @@ def batch_completions( except Exception as e: error_str = str(e) if "data parallel group is already initialized" in error_str: - pass + pass else: raise VLLMError(status_code=0, message=error_str) sampling_params = SamplingParams(**optional_params) - prompts = [] + prompts = [] if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] for message in messages: prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=message + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=message, ) 
prompts.append(prompt) else: for message in messages: prompt = prompt_factory(model=model, messages=message) prompts.append(prompt) - + if llm: outputs = llm.generate(prompts, sampling_params) else: - raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") + raise VLLMError( + status_code=0, message="Need to pass in a model name to initialize vllm" + ) final_outputs = [] for output in outputs: @@ -170,20 +176,21 @@ def batch_completions( model_response["choices"][0]["message"]["content"] = output.outputs[0].text ## CALCULATING USAGE - prompt_tokens = len(output.prompt_token_ids) - completion_tokens = len(output.outputs[0].token_ids) + prompt_tokens = len(output.prompt_token_ids) + completion_tokens = len(output.outputs[0].token_ids) model_response["created"] = int(time.time()) model_response["model"] = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage final_outputs.append(model_response) return final_outputs + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/main.py b/litellm/main.py index fb3be6233..af2460a6d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5,7 +5,7 @@ # | | # +-----------------------------------------------+ # -# Thank you ! We ❤️ you! - Krrish & Ishaan +# Thank you ! We ❤️ you! - Krrish & Ishaan import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any @@ -29,12 +29,12 @@ from litellm.utils import ( completion_with_fallbacks, get_llm_provider, get_api_key, - mock_completion_streaming_obj, - convert_to_model_response_object, - token_counter, - Usage, + mock_completion_streaming_obj, + convert_to_model_response_object, + token_counter, + Usage, get_optional_params_embeddings, - get_optional_params_image_gen + get_optional_params_image_gen, ) from .llms import ( anthropic, @@ -56,11 +56,16 @@ from .llms import ( palm, gemini, vertex_ai, - maritalk) + maritalk, +) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion from .llms.huggingface_restapi import Huggingface -from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt +from .llms.prompt_templates.factory import ( + prompt_factory, + custom_prompt, + function_call_prompt, +) import tiktoken from concurrent.futures import ThreadPoolExecutor from typing import Callable, List, Optional, Dict, Union, Mapping @@ -75,8 +80,8 @@ from litellm.utils import ( TextChoices, EmbeddingResponse, read_config_args, - Choices, - Message + Choices, + Message, ) ####### ENVIRONMENT VARIABLES ################### @@ -87,35 +92,39 @@ azure_chat_completions = AzureChatCompletion() huggingface = Huggingface() ####### COMPLETION ENDPOINTS ################ + class LiteLLM: + def __init__( + self, + *, + api_key=None, + organization: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[float] = 600, + max_retries: Optional[int] = litellm.num_retries, + default_headers: Optional[Mapping[str, str]] = None, + ): + self.params = locals() + self.chat = Chat(self.params) - def __init__(self, *, - api_key=None, - organization: Optional[str] = None, - base_url: Optional[str]= None, - timeout: Optional[float] = 600, - max_retries: Optional[int] = litellm.num_retries, - default_headers: Optional[Mapping[str, str]] = None,): - self.params = 
locals() - self.chat = Chat(self.params) -class Chat(): +class Chat: + def __init__(self, params): + self.params = params + self.completions = Completions(self.params) - def __init__(self, params): - self.params = params - self.completions = Completions(self.params) - -class Completions(): - - def __init__(self, params): - self.params = params - def create(self, messages, model=None, **kwargs): - for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get('model') - response = completion(model=model, messages=messages, **self.params) - return response +class Completions: + def __init__(self, params): + self.params = params + + def create(self, messages, model=None, **kwargs): + for k, v in kwargs.items(): + self.params[k] = v + model = model or self.params.get("model") + response = completion(model=model, messages=messages, **self.params) + return response + @client async def acompletion(*args, **kwargs): @@ -139,7 +148,7 @@ async def acompletion(*args, **kwargs): frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. api_base (str, optional): Base URL for the API (default is None). api_version (str, optional): API version (default is None). api_key (str, optional): API key (default is None). @@ -159,10 +168,10 @@ async def acompletion(*args, **kwargs): """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO COMPLETION ### + ### PASS ARGS TO COMPLETION ### kwargs["acompletion"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(completion, *args, **kwargs) @@ -170,10 +179,13 @@ async def acompletion(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" @@ -183,39 +195,58 @@ async def acompletion(*args, **kwargs): or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "huggingface" or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. - if kwargs.get("stream", False): + or custom_llm_provider == "vertex_ai" + ): # currently implemented aiohttp calls for just azure and openai, soon all. 
+ if kwargs.get("stream", False): response = completion(*args, **kwargs) else: # Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False): # return an async generator - return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args) - else: + response = await loop.run_in_executor(None, func_with_context) + if kwargs.get("stream", False): # return an async generator + return _async_streaming( + response=response, + model=model, + custom_llm_provider=custom_llm_provider, + args=args, + ) + else: return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) -async def _async_streaming(response, model, custom_llm_provider, args): - try: + +async def _async_streaming(response, model, custom_llm_provider, args): + try: print_verbose(f"received response in _async_streaming: {response}") - async for line in response: + async for line in response: print_verbose(f"line in async streaming: {line}") yield line - except Exception as e: + except Exception as e: raise e -def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs): + +def mock_completion( + model: str, + messages: List, + stream: Optional[bool] = False, + mock_response: str = "This is a mock request", + **kwargs, +): """ Generate a mock completion response for testing or debugging purposes. 
@@ -242,9 +273,11 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False, model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, - response = mock_completion_streaming_obj(model_response, mock_response=mock_response, model=model) + response = mock_completion_streaming_obj( + model_response, mock_response=mock_response, model=model + ) return response - + model_response["choices"][0]["message"]["content"] = mock_response model_response["created"] = int(time.time()) model_response["model"] = model @@ -254,6 +287,7 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False, traceback.print_exc() raise Exception("Mock completion response failed") + @client def completion( model: str, @@ -269,7 +303,7 @@ def completion( stop=None, max_tokens: Optional[float] = None, presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float]=None, + frequency_penalty: Optional[float] = None, logit_bias: Optional[dict] = None, user: Optional[str] = None, # openai v1.0+ new params @@ -277,13 +311,12 @@ def completion( seed: Optional[int] = None, tools: Optional[List] = None, tool_choice: Optional[str] = None, - deployment_id = None, + deployment_id=None, # set api_base, api_version, api_key base_url: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. # Optional liteLLM function params **kwargs, ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -306,7 +339,7 @@ def completion( frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. api_base (str, optional): Base URL for the API (default is None). api_version (str, optional): API version (default is None). api_key (str, optional): API key (default is None). 
@@ -326,26 +359,26 @@ def completion( """ ######### unpacking kwargs ##################### args = locals() - api_base = kwargs.get('api_base', None) - mock_response = kwargs.get('mock_response', None) - force_timeout= kwargs.get('force_timeout', 600) ## deprecated - logger_fn = kwargs.get('logger_fn', None) - verbose = kwargs.get('verbose', False) - custom_llm_provider = kwargs.get('custom_llm_provider', None) - litellm_logging_obj = kwargs.get('litellm_logging_obj', None) - id = kwargs.get('id', None) - metadata = kwargs.get('metadata', None) - model_info = kwargs.get('model_info', None) - proxy_server_request = kwargs.get('proxy_server_request', None) - fallbacks = kwargs.get('fallbacks', None) + api_base = kwargs.get("api_base", None) + mock_response = kwargs.get("mock_response", None) + force_timeout = kwargs.get("force_timeout", 600) ## deprecated + logger_fn = kwargs.get("logger_fn", None) + verbose = kwargs.get("verbose", False) + custom_llm_provider = kwargs.get("custom_llm_provider", None) + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + id = kwargs.get("id", None) + metadata = kwargs.get("metadata", None) + model_info = kwargs.get("model_info", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + fallbacks = kwargs.get("fallbacks", None) headers = kwargs.get("headers", None) - num_retries = kwargs.get("num_retries", None) ## deprecated + num_retries = kwargs.get("num_retries", None) ## deprecated max_retries = kwargs.get("max_retries", None) context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - ### CUSTOM MODEL COST ### + ### CUSTOM MODEL COST ### input_cost_per_token = kwargs.get("input_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None) - ### CUSTOM PROMPT TEMPLATE ### + ### CUSTOM PROMPT TEMPLATE ### initial_prompt_value = kwargs.get("initial_prompt_value", None) roles = kwargs.get("roles", None) final_prompt_value = kwargs.get("final_prompt_value", None) @@ -353,104 +386,199 @@ def completion( eos_token = kwargs.get("eos_token", None) preset_cache_key = kwargs.get("preset_cache_key", None) hf_model_name = kwargs.get("hf_model_name", None) - ### ASYNC CALLS ### + ### ASYNC CALLS ### acompletion = kwargs.get("acompletion", False) client = kwargs.get("client", None) ######## end of unpacking kwargs ########### - openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] - litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"] + openai_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + 
"logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] + litellm_params = [ + "metadata", + "acompletion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", + "caching_groups", + ] default_params = openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider if mock_response: - return mock_completion(model, messages, stream=stream, mock_response=mock_response) + return mock_completion( + model, messages, stream=stream, mock_response=mock_response + ) if timeout is None: - timeout = kwargs.get("request_timeout", None) or 600 # set timeout for 10 minutes by default + timeout = ( + kwargs.get("request_timeout", None) or 600 + ) # set timeout for 10 minutes by default timeout = float(timeout) try: - if base_url is not None: + if base_url is not None: api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) + if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) num_retries = max_retries logging = litellm_logging_obj - fallbacks = ( - fallbacks - or litellm.model_fallbacks - ) + fallbacks = fallbacks or litellm.model_fallbacks if fallbacks is not None: return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [m["litellm_params"] for m in model_list if m["model_name"] == model] + if model_list is not None: + deployments = [ + m["litellm_params"] for m in model_list if m["model_name"] == model + ] return batch_completion_models(deployments=deployments, **args) if litellm.model_alias_map and model in litellm.model_alias_map: model = litellm.model_alias_map[ model ] # update the model to the actual value if an alias has been passed in model_response = ModelResponse() - if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider="azure" - if deployment_id != None: # azure llms - model=deployment_id - custom_llm_provider="azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) + if ( + kwargs.get("azure", False) == True + ): # don't remove flag check, to remain backwards compatible for repos like Codium + custom_llm_provider = "azure" + if deployment_id != None: # azure llms + model = deployment_id + custom_llm_provider = "azure" + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) ### 
REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - litellm.register_model({ - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider + if input_cost_per_token is not None and output_cost_per_token is not None: + litellm.register_model( + { + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + } } - }) + ) ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token: + custom_prompt_dict = {} # type: ignore + if ( + initial_prompt_value + or roles + or final_prompt_value + or bos_token + or eos_token + ): custom_prompt_dict = {model: {}} if initial_prompt_value: custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: + if roles: custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: + if final_prompt_value: custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value if bos_token: custom_prompt_dict[model]["bos_token"] = bos_token if eos_token: custom_prompt_dict[model]["eos_token"] = eos_token - model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model + model_api_key = get_api_key( + llm_provider=custom_llm_provider, dynamic_api_key=api_key + ) # get the api key from the environment if required for the model if model_api_key and "sk-litellm" in model_api_key: api_base = "https://proxy.litellm.ai" - custom_llm_provider = "openai" + custom_llm_provider = "openai" api_key = model_api_key - if dynamic_api_key is not None: - api_key = dynamic_api_key + if dynamic_api_key is not None: + api_key = dynamic_api_key # check if user passed in any of the OpenAI optional params optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - **non_default_params + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + max_retries=max_retries, + **non_default_params, + ) + + if litellm.add_function_to_prompt and optional_params.get( + "functions_unsupported_model", None + ): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop( + "functions_unsupported_model" + ) + messages = function_call_prompt( + messages=messages, functions=functions_unsupported_model ) - - if litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to 
add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop("functions_unsupported_model") - messages = function_call_prompt(messages=messages, functions=functions_unsupported_model) # For logging - save the values of the litellm-specific params passed in litellm_params = get_litellm_params( @@ -461,53 +589,50 @@ def completion( verbose=verbose, custom_llm_provider=custom_llm_provider, api_base=api_base, - litellm_call_id=kwargs.get('litellm_call_id', None), + litellm_call_id=kwargs.get("litellm_call_id", None), model_alias_map=litellm.model_alias_map, completion_call_id=id, metadata=metadata, model_info=model_info, proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key + preset_cache_key=preset_cache_key, + ) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params, ) - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params) if custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_OPENAI_API_KEY") or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") ) - azure_ad_token = ( - optional_params.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.AzureOpenAIConfig.get_config() + config = litellm.AzureOpenAIConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v ## COMPLETION CALL @@ -525,10 +650,10 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, + logging_obj=logging, + acompletion=acompletion, timeout=timeout, - client=client # pass AsyncAzureOpenAI, AzureOpenAI client + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client ) if optional_params.get("stream", False) or acompletion == True: @@ -556,7 +681,7 @@ def completion( # note: if a user sets a custom base - we should ensure this works # allow for the setting of dynamic and stateful api-bases api_base = ( - api_base # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api base from there + api_base # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api base from there or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1" @@ -564,25 +689,24 @@ def completion( openai.organization = ( litellm.organization or get_secret("OPENAI_ORGANIZATION") - or None # default - 
https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 ) # set API KEY api_key = ( - api_key or # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.OpenAIConfig.get_config() + config = litellm.OpenAIConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v ## COMPLETION CALL @@ -601,7 +725,7 @@ def completion( logger_fn=logger_fn, timeout=timeout, custom_prompt_dict=custom_prompt_dict, - client=client # pass AsyncOpenAI, OpenAI client + client=client, # pass AsyncOpenAI, OpenAI client ) except Exception as e: ## LOGGING - log the original exception returned @@ -639,31 +763,34 @@ def completion( # set API KEY api_key = ( - api_key or - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.OpenAITextCompletionConfig.get_config() + config = litellm.OpenAITextCompletionConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v if litellm.organization: openai.organization = litellm.organization - if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list: + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] # https://platform.openai.com/docs/api-reference/completions/create prompt = messages[0]["content"] else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore + prompt = " ".join([message["content"] for message in messages]) # type: ignore ## COMPLETION CALL model_response = openai_text_completions.completion( model=model, @@ -676,9 +803,9 @@ def completion( logging_obj=logging, optional_params=optional_params, litellm_params=litellm_params, - logger_fn=logger_fn + logger_fn=logger_fn, ) - + if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( @@ -689,16 +816,16 @@ def completion( ) response = model_response elif ( - "replicate" in model or - custom_llm_provider == "replicate" or - model in litellm.replicate_models + "replicate" in model + or custom_llm_provider == "replicate" + or model in litellm.replicate_models ): # Setting the 
relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") replicate_key = None replicate_key = ( api_key or litellm.replicate_key - or litellm.api_key + or litellm.api_key or get_secret("REPLICATE_API_KEY") or get_secret("REPLICATE_API_TOKEN") ) @@ -710,10 +837,7 @@ def completion( or "https://api.replicate.com/v1" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = replicate.completion( model=model, @@ -724,14 +848,14 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens + encoding=encoding, # for calculating input/output tokens api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore + model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore if optional_params.get("stream", False) or acompletion == True: ## LOGGING @@ -743,12 +867,12 @@ def completion( response = model_response - elif custom_llm_provider=="anthropic": + elif custom_llm_provider == "anthropic": api_key = ( - api_key - or litellm.anthropic_key + api_key + or litellm.anthropic_key or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") + or os.environ.get("ANTHROPIC_API_KEY") ) api_base = ( api_base @@ -756,10 +880,7 @@ def completion( or get_secret("ANTHROPIC_API_BASE") or "https://api.anthropic.com/v1/complete" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict response = anthropic.completion( model=model, messages=messages, @@ -770,14 +891,19 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens + encoding=encoding, # for calculating input/output tokens api_key=api_key, - logging_obj=logging, + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging) - + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="anthropic", + logging_obj=logging, + ) + if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( @@ -788,7 +914,10 @@ def completion( response = response elif custom_llm_provider == "nlp_cloud": nlp_cloud_key = ( - api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key + api_key + or litellm.nlp_cloud_key + or get_secret("NLP_CLOUD_API_KEY") + or litellm.api_key ) api_base = ( @@ -809,13 +938,18 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=nlp_cloud_key, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging) - + response = CustomStreamWrapper( + response, + 
model, + custom_llm_provider="nlp_cloud", + logging_obj=logging, + ) + if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( @@ -827,7 +961,11 @@ def completion( response = response elif custom_llm_provider == "aleph_alpha": aleph_alpha_key = ( - api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + or get_secret("ALEPHALPHA_API_KEY") + or litellm.api_key ) api_base = ( @@ -849,12 +987,17 @@ def completion( encoding=encoding, default_max_tokens_to_sample=litellm.max_tokens, api_key=aleph_alpha_key, - logging_obj=logging # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="aleph_alpha", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="aleph_alpha", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "cohere": @@ -872,7 +1015,7 @@ def completion( or get_secret("COHERE_API_BASE") or "https://api.cohere.ai/v1/generate" ) - + model_response = cohere.completion( model=model, messages=messages, @@ -884,12 +1027,17 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=cohere_key, - logging_obj=logging # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "maritalk": @@ -906,7 +1054,7 @@ def completion( or get_secret("MARITALK_API_BASE") or "https://chat.maritaca.ai/api/chat/inference" ) - + model_response = maritalk.completion( model=model, messages=messages, @@ -918,17 +1066,20 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=maritalk_key, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="maritalk", + logging_obj=logging, + ) return response response = model_response - elif ( - custom_llm_provider == "huggingface" - ): + elif custom_llm_provider == "huggingface": custom_llm_provider = "huggingface" huggingface_key = ( api_key @@ -937,35 +1088,36 @@ def completion( or os.environ.get("HUGGINGFACE_API_KEY") or litellm.api_key ) - hf_headers = ( - headers - or litellm.headers - ) + hf_headers = headers or litellm.headers - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or 
litellm.custom_prompt_dict model_response = huggingface.completion( model=model, messages=messages, - api_base=api_base, # type: ignore + api_base=api_base, # type: ignore headers=hf_headers, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, + encoding=encoding, + api_key=huggingface_key, acompletion=acompletion, logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + custom_prompt_dict=custom_prompt_dict, ) - if "stream" in optional_params and optional_params["stream"] == True and acompletion is False: + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion is False + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="huggingface", logging_obj=logging + model_response, + model, + custom_llm_provider="huggingface", + logging_obj=logging, ) return response response = model_response @@ -975,73 +1127,62 @@ def completion( model=model, messages=messages, model_response=model_response, - api_base=api_base, # type: ignore + api_base=api_base, # type: ignore print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, api_key=None, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="oobabooga", logging_obj=logging + model_response, + model, + custom_llm_provider="oobabooga", + logging_obj=logging, ) return response response = model_response elif custom_llm_provider == "openrouter": - api_base = ( - api_base - or litellm.api_base - or "https://openrouter.ai/api/v1" - ) + api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" api_key = ( - api_key or - litellm.api_key or - litellm.openrouter_key or - get_secret("OPENROUTER_API_KEY") or - get_secret("OR_API_KEY") + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") ) - openrouter_site_url = ( - get_secret("OR_SITE_URL") - or "https://litellm.ai" - ) + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - openrouter_app_name = ( - get_secret("OR_APP_NAME") - or "liteLLM" - ) + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" headers = ( - headers or - litellm.headers or - { + headers + or litellm.headers + or { "HTTP-Referer": openrouter_site_url, "X-Title": openrouter_app_name, } ) ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): + config = openrouter.OpenrouterConfig.get_config() + for k, v in config.items(): if k == "extra_body": # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: + if "extra_body" in optional_params: optional_params[k].update(v) else: optional_params[k] = v - elif k not in optional_params: + elif k not in optional_params: optional_params[k] = v - data = { - "model": model, - "messages": messages, - **optional_params - } + data = {"model": model, "messages": messages, **optional_params} ## COMPLETION CALL response = openai_chat_completions.completion( @@ -1055,15 +1196,19 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - 
logging_obj=logging, + logging_obj=logging, acompletion=acompletion, - timeout=timeout + timeout=timeout, ) ## LOGGING logging.post_call( input=messages, api_key=openai.api_key, original_response=response ) - elif custom_llm_provider == "together_ai" or ("togethercomputer" in model) or (model in litellm.together_ai_models): + elif ( + custom_llm_provider == "together_ai" + or ("togethercomputer" in model) + or (model in litellm.together_ai_models) + ): custom_llm_provider = "together_ai" together_ai_key = ( api_key @@ -1080,11 +1225,8 @@ def completion( or "https://api.together.xyz/inference" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) - + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + model_response = together_ai.completion( model=model, messages=messages, @@ -1097,22 +1239,24 @@ def completion( encoding=encoding, api_key=together_ai_key, logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + custom_prompt_dict=custom_prompt_dict, ) - if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + if ( + "stream_tokens" in optional_params + and optional_params["stream_tokens"] == True + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="together_ai", logging_obj=logging + model_response, + model, + custom_llm_provider="together_ai", + logging_obj=logging, ) return response response = model_response elif custom_llm_provider == "palm": - palm_api_key = ( - api_key - or get_secret("PALM_API_KEY") - or litellm.api_key - ) - + palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key + # palm does not support streaming as yet :( model_response = palm.completion( model=model, @@ -1124,7 +1268,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=palm_api_key, - logging_obj=logging + logging_obj=logging, ) # fake palm streaming if "stream" in optional_params and optional_params["stream"] == True: @@ -1139,10 +1283,10 @@ def completion( gemini_api_key = ( api_key or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work + or get_secret("PALM_API_KEY") # older palm api key should also work or litellm.api_key ) - + # palm does not support streaming as yet :( model_response = gemini.completion( model=model, @@ -1156,14 +1300,14 @@ def completion( api_key=gemini_api_key, logging_obj=logging, acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict + custom_prompt_dict=custom_prompt_dict, ) response = model_response elif custom_llm_provider == "vertex_ai": - vertex_ai_project = (litellm.vertex_project - or get_secret("VERTEXAI_PROJECT")) - vertex_ai_location = (litellm.vertex_location - or get_secret("VERTEXAI_LOCATION")) + vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT") + vertex_ai_location = litellm.vertex_location or get_secret( + "VERTEXAI_LOCATION" + ) model_response = vertex_ai.completion( model=model, @@ -1176,14 +1320,21 @@ def completion( encoding=encoding, vertex_location=vertex_ai_location, vertex_project=vertex_ai_project, - logging_obj=logging, - acompletion=acompletion + logging_obj=logging, + acompletion=acompletion, ) - - if "stream" in optional_params and optional_params["stream"] == True and acompletion == False: + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): response = CustomStreamWrapper( - model_response, model, custom_llm_provider="vertex_ai", 
logging_obj=logging - ) + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "ai21": @@ -1201,7 +1352,7 @@ def completion( or get_secret("AI21_API_BASE") or "https://api.ai21.com/studio/v1/" ) - + model_response = ai21.completion( model=model, messages=messages, @@ -1213,16 +1364,19 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=ai21_key, - logging_obj=logging + logging_obj=logging, ) - + if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="ai21", logging_obj=logging + model_response, + model, + custom_llm_provider="ai21", + logging_obj=logging, ) return response - + ## RESPONSE OBJECT response = model_response elif custom_llm_provider == "sagemaker": @@ -1238,18 +1392,23 @@ def completion( hf_model_name=hf_model_name, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"]==True: ## [BETA] + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] # sagemaker does not support streaming as of now so we're faking streaming: # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611 # "SageMaker is currently not supporting streaming responses." - + # fake streaming for sagemaker print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") resp_string = model_response["choices"][0]["message"]["content"] response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging + resp_string, + model, + custom_llm_provider="sagemaker", + logging_obj=logging, ) return response @@ -1257,10 +1416,7 @@ def completion( response = model_response elif custom_llm_provider == "bedrock": # boto3 reads keys from .env - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict response = bedrock.completion( model=model, messages=messages, @@ -1274,18 +1430,23 @@ def completion( logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - if "ai21" in model: + if "ai21" in model: response = CustomStreamWrapper( - response, model, custom_llm_provider="bedrock", logging_obj=logging + response, + model, + custom_llm_provider="bedrock", + logging_obj=logging, ) else: response = CustomStreamWrapper( - iter(response), model, custom_llm_provider="bedrock", logging_obj=logging + iter(response), + model, + custom_llm_provider="bedrock", + logging_obj=logging, ) - + if optional_params.get("stream", False): ## LOGGING logging.post_call( @@ -1294,7 +1455,6 @@ def completion( original_response=response, ) - ## RESPONSE OBJECT response = response elif custom_llm_provider == "vllm": @@ -1307,13 +1467,18 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"] == True: ## [BETA] + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="vllm", logging_obj=logging + model_response, + model, + custom_llm_provider="vllm", + logging_obj=logging, 
) return response @@ -1321,27 +1486,27 @@ def completion( response = model_response elif custom_llm_provider == "ollama": api_base = ( - litellm.api_base or - api_base or - get_secret("OLLAMA_API_BASE") or - "http://localhost:11434" - - ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: - prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider) + prompt = prompt_factory( + model=model, + messages=messages, + custom_llm_provider=custom_llm_provider, + ) if isinstance(prompt, dict): # for multimode models - ollama/llava prompt_factory returns a dict { # "prompt": prompt, @@ -1351,10 +1516,19 @@ def completion( optional_params["images"] = images ## LOGGING - generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding) + generator = ollama.get_ollama_response( + api_base, + model, + prompt, + optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) if acompletion is True or optional_params.get("stream", False) == True: return generator - + response = generator elif ( custom_llm_provider == "baseten" @@ -1362,7 +1536,10 @@ def completion( ): custom_llm_provider = "baseten" baseten_key = ( - api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY") or litellm.api_key + api_key + or litellm.baseten_key + or os.environ.get("BASETEN_API_KEY") + or litellm.api_key ) model_response = baseten.completion( @@ -1373,25 +1550,24 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - api_key=baseten_key, - logging_obj=logging + encoding=encoding, + api_key=baseten_key, + logging_obj=logging, ) - if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True): + if inspect.isgenerator(model_response) or ( + "stream" in optional_params and optional_params["stream"] == True + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="baseten", logging_obj=logging + model_response, + model, + custom_llm_provider="baseten", + logging_obj=logging, ) return response response = model_response - elif ( - custom_llm_provider == "petals" - or model in litellm.petals_models - ): - api_base = ( - api_base or - litellm.api_base - ) + elif custom_llm_provider == "petals" or model in litellm.petals_models: + api_base = api_base or litellm.api_base custom_llm_provider = "petals" stream = optional_params.pop("stream", False) @@ -1404,29 +1580,28 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - 
logging_obj=logging + encoding=encoding, + logging_obj=logging, ) - if stream==True: ## [BETA] + if stream == True: ## [BETA] # Fake streaming for petals resp_string = model_response["choices"][0]["message"]["content"] response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="petals", logging_obj=logging + resp_string, + model, + custom_llm_provider="petals", + logging_obj=logging, ) return response response = model_response - elif ( - custom_llm_provider == "custom" - ): + elif custom_llm_provider == "custom": import requests - url = ( - litellm.api_base or - api_base or - "" - ) + url = litellm.api_base or api_base or "" if url == None or url == "": - raise ValueError("api_base not set. Set api_base or litellm.api_base for custom endpoints") + raise ValueError( + "api_base not set. Set api_base or litellm.api_base for custom endpoints" + ) """ assume input to custom LLM api bases follow this format: @@ -1445,17 +1620,20 @@ def completion( ) """ - prompt = " ".join([message["content"] for message in messages]) # type: ignore - resp = requests.post(url, json={ - 'model': model, - 'params': { - 'prompt': [prompt], - 'max_tokens': max_tokens, - 'temperature': temperature, - 'top_p': top_p, - 'top_k': kwargs.get('top_k', 40), - } - }) + prompt = " ".join([message["content"] for message in messages]) # type: ignore + resp = requests.post( + url, + json={ + "model": model, + "params": { + "prompt": [prompt], + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": kwargs.get("top_k", 40), + }, + }, + ) response_json = resp.json() """ assume all responses from custom api_bases of this format: @@ -1470,7 +1648,7 @@ def completion( ] } """ - string_response = response_json['data'][0]['output'][0] + string_response = response_json["data"][0]["output"][0] ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = string_response model_response["created"] = int(time.time()) @@ -1484,8 +1662,11 @@ def completion( except Exception as e: ## Map to OpenAI Exception raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) def completion_with_retries(*args, **kwargs): @@ -1495,17 +1676,26 @@ def completion_with_retries(*args, **kwargs): try: import tenacity except Exception as e: - raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}") - + raise Exception( + f"tenacity import failed please run `pip install tenacity`. 
Error{e}" + ) + num_retries = kwargs.pop("num_retries", 3) retry_strategy = kwargs.pop("retry_strategy", "constant_retry") original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) return retryer(original_function, *args, **kwargs) + async def acompletion_with_retries(*args, **kwargs): """ Executes a litellm.completion() with 3 retries @@ -1513,19 +1703,26 @@ async def acompletion_with_retries(*args, **kwargs): try: import tenacity except Exception as e: - raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}") - + raise Exception( + f"tenacity import failed please run `pip install tenacity`. Error{e}" + ) + num_retries = kwargs.pop("num_retries", 3) retry_strategy = kwargs.pop("retry_strategy", "constant_retry") original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) return await retryer(original_function, *args, **kwargs) - def batch_completion( model: str, # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create @@ -1539,13 +1736,14 @@ def batch_completion( stop=None, max_tokens: Optional[float] = None, presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float]=None, + frequency_penalty: Optional[float] = None, logit_bias: Optional[dict] = None, user: Optional[str] = None, - deployment_id = None, + deployment_id=None, request_timeout: Optional[int] = None, # Optional liteLLM function params - **kwargs): + **kwargs, +): """ Batch litellm.completion function for a given model. 
@@ -1594,15 +1792,22 @@ def batch_completion( user=user, # params to identify the model model=model, - custom_llm_provider=custom_llm_provider + custom_llm_provider=custom_llm_provider, ) - results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params) - # all non VLLM models for batch completion models + results = vllm.batch_completions( + model=model, + messages=batch_messages, + custom_prompt_dict=litellm.custom_prompt_dict, + optional_params=optional_params, + ) + # all non VLLM models for batch completion models else: + def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): - yield lst[i:i + n] + yield lst[i : i + n] + with ThreadPoolExecutor(max_workers=100) as executor: for sub_batch in chunks(batch_messages, 100): for message_list in sub_batch: @@ -1611,13 +1816,16 @@ def batch_completion( original_kwargs = {} if "kwargs" in kwargs_modified: original_kwargs = kwargs_modified.pop("kwargs") - future = executor.submit(completion, **kwargs_modified, **original_kwargs) + future = executor.submit( + completion, **kwargs_modified, **original_kwargs + ) completions.append(future) # Retrieve the results from the futures results = [future.result() for future in completions] return results + # send one request to multiple models # return as soon as one of the llms responds def batch_completion_models(*args, **kwargs): @@ -1639,6 +1847,7 @@ def batch_completion_models(*args, **kwargs): It sends requests concurrently and returns the response from the first model that responds. """ import concurrent + if "model" in kwargs: kwargs.pop("model") if "models" in kwargs: @@ -1647,21 +1856,29 @@ def batch_completion_models(*args, **kwargs): futures = {} with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor: for model in models: - futures[model] = executor.submit(completion, *args, model=model, **kwargs) + futures[model] = executor.submit( + completion, *args, model=model, **kwargs + ) - for model, future in sorted(futures.items(), key=lambda x: models.index(x[0])): + for model, future in sorted( + futures.items(), key=lambda x: models.index(x[0]) + ): if future.result() is not None: return future.result() - elif "deployments" in kwargs: + elif "deployments" in kwargs: deployments = kwargs["deployments"] kwargs.pop("deployments") kwargs.pop("model_list") nested_kwargs = kwargs.pop("kwargs", {}) futures = {} - with concurrent.futures.ThreadPoolExecutor(max_workers=len(deployments)) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(deployments) + ) as executor: for deployment in deployments: - for key in kwargs.keys(): - if key not in deployment: # don't override deployment values e.g. model name, api base, etc. + for key in kwargs.keys(): + if ( + key not in deployment + ): # don't override deployment values e.g. model name, api base, etc. 
deployment[key] = kwargs[key] kwargs = {**deployment, **nested_kwargs} futures[deployment["model"]] = executor.submit(completion, **kwargs) @@ -1669,7 +1886,9 @@ def batch_completion_models(*args, **kwargs): while futures: # wait for the first returned future print_verbose("\n\n waiting for next result\n\n") - done, _ = concurrent.futures.wait(futures.values(), return_when=concurrent.futures.FIRST_COMPLETED) + done, _ = concurrent.futures.wait( + futures.values(), return_when=concurrent.futures.FIRST_COMPLETED + ) print_verbose(f"done list\n{done}") for future in done: try: @@ -1677,7 +1896,9 @@ def batch_completion_models(*args, **kwargs): return result except Exception as e: # if model 1 fails, continue with response from model 2, model3 - print_verbose(f"\n\ngot an exception, ignoring, removing from futures") + print_verbose( + f"\n\ngot an exception, ignoring, removing from futures" + ) print_verbose(futures) new_futures = {} for key, value in futures.items(): @@ -1690,12 +1911,12 @@ def batch_completion_models(*args, **kwargs): print_verbose(f"new futures{futures}") continue - print_verbose("\n\ndone looping through futures\n\n") print_verbose(futures) return None # If no response is received from any model + def batch_completion_models_all_responses(*args, **kwargs): """ Send a request to multiple language models concurrently and return a list of responses @@ -1737,6 +1958,7 @@ def batch_completion_models_all_responses(*args, **kwargs): return responses + ### EMBEDDING ENDPOINTS #################### @client async def aembedding(*args, **kwargs): @@ -1752,10 +1974,10 @@ async def aembedding(*args, **kwargs): """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Embedding ### + ### PASS ARGS TO Embedding ### kwargs["aembedding"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(embedding, *args, **kwargs) @@ -1763,50 +1985,60 @@ async def aembedding(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "openrouter" or custom_llm_provider == "deepinfra" or custom_llm_provider == "perplexity" - or custom_llm_provider == "ollama"): # currently implemented aiohttp calls for just azure and openai, soon all. + or custom_llm_provider == "ollama" + ): # currently implemented aiohttp calls for just azure and openai, soon all. 
# Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = await loop.run_in_executor(None, func_with_context) return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) + @client def embedding( - model, - input=[], + model, + input=[], # Optional params - timeout=600, # default to 10 minutes + timeout=600, # default to 10 minutes # set api_base, api_version, api_key api_base: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, api_type: Optional[str] = None, - caching: bool=False, - user: Optional[str]=None, + caching: bool = False, + user: Optional[str] = None, custom_llm_provider=None, - litellm_call_id=None, + litellm_call_id=None, litellm_logging_obj=None, - logger_fn=None, - **kwargs + logger_fn=None, + **kwargs, ): """ Embedding function that calls an API to generate embeddings for the given input. @@ -1840,43 +2072,116 @@ def embedding( encoding_format = kwargs.get("encoding_format", None) proxy_server_request = kwargs.get("proxy_server_request", None) aembedding = kwargs.get("aembedding", None) - openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"] - litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] + openai_params = [ + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "encoding_format", + ] + litellm_params = [ + "metadata", + "aembedding", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + ] default_params 
= openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider - - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) - optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params) + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + optional_params = get_optional_params_embeddings( + user=user, + encoding_format=encoding_format, + custom_llm_provider=custom_llm_provider, + **non_default_params, + ) try: response = None logging = litellm_logging_obj - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}}) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": azure, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "aembedding": aembedding, + "preset_cache_key": None, + "stream_response": {}, + }, + ) if azure == True or custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) - azure_ad_token = ( - kwargs.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_API_KEY") ) ## EMBEDDING CALL response = azure_chat_completions.embedding( @@ -1888,12 +2193,14 @@ def embedding( azure_ad_token=azure_ad_token, logging_obj=logging, timeout=timeout, - model_response=EmbeddingResponse(), + model_response=EmbeddingResponse(), optional_params=optional_params, client=client, - aembedding=aembedding + aembedding=aembedding, ) - elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai": + elif ( + model in litellm.open_ai_embedding_models or custom_llm_provider == "openai" + ): api_base = ( api_base or litellm.api_base @@ -1903,19 +2210,18 @@ def embedding( openai.organization = ( litellm.organization or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + or None # default - 
https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 ) # set API KEY api_key = ( - api_key or - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) api_type = "openai" api_version = None - ## EMBEDDING CALL response = openai_chat_completions.embedding( model=model, @@ -1924,7 +2230,7 @@ def embedding( api_key=api_key, logging_obj=logging, timeout=timeout, - model_response=EmbeddingResponse(), + model_response=EmbeddingResponse(), optional_params=optional_params, client=client, aembedding=aembedding, @@ -1944,8 +2250,7 @@ def embedding( encoding=encoding, api_key=cohere_key, logging_obj=logging, - model_response= EmbeddingResponse() - + model_response=EmbeddingResponse(), ) elif custom_llm_provider == "huggingface": api_key = ( @@ -1961,7 +2266,7 @@ def embedding( api_key=api_key, api_base=api_base, logging_obj=logging, - model_response= EmbeddingResponse() + model_response=EmbeddingResponse(), ) elif custom_llm_provider == "bedrock": response = bedrock.embedding( @@ -1970,9 +2275,9 @@ def embedding( encoding=encoding, logging_obj=logging, optional_params=optional_params, - model_response= EmbeddingResponse() + model_response=EmbeddingResponse(), ) - elif custom_llm_provider == "ollama": + elif custom_llm_provider == "ollama": if aembedding == True: response = ollama.ollama_aembeddings( model=model, @@ -1982,15 +2287,15 @@ def embedding( optional_params=optional_params, model_response=EmbeddingResponse(), ) - elif custom_llm_provider == "sagemaker": + elif custom_llm_provider == "sagemaker": response = sagemaker.embedding( model=model, input=input, encoding=encoding, logging_obj=logging, optional_params=optional_params, - model_response= EmbeddingResponse(), - print_verbose=print_verbose + model_response=EmbeddingResponse(), + print_verbose=print_verbose, ) else: args = locals() @@ -2014,14 +2319,14 @@ def embedding( ###### Text Completion ################ async def atext_completion(*args, **kwargs): """ - Implemented to handle async streaming for the text completion endpoint + Implemented to handle async streaming for the text completion endpoint """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO COMPLETION ### + ### PASS ARGS TO COMPLETION ### kwargs["acompletion"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(text_completion, *args, **kwargs) @@ -2029,10 +2334,13 @@ async def atext_completion(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" @@ -2042,58 +2350,91 @@ async def atext_completion(*args, **kwargs): or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "huggingface" or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. 
- if kwargs.get("stream", False): + or custom_llm_provider == "vertex_ai" + ): # currently implemented aiohttp calls for just azure and openai, soon all. + if kwargs.get("stream", False): response = text_completion(*args, **kwargs) else: # Await normally response = await loop.run_in_executor(None, func_with_context) if asyncio.iscoroutine(response): response = await response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False): # return an async generator - return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args) - else: + response = await loop.run_in_executor(None, func_with_context) + if kwargs.get("stream", False): # return an async generator + return _async_streaming( + response=response, + model=model, + custom_llm_provider=custom_llm_provider, + args=args, + ) + else: return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) + def text_completion( - prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for. - model: Optional[str]=None, # Optional: either `model` or `engine` can be set - best_of: Optional[int] = None, # Optional: Generates best_of completions server-side. - echo: Optional[bool] = None, # Optional: Echo back the prompt in addition to the completion. - frequency_penalty: Optional[float] = None, # Optional: Penalize new tokens based on their existing frequency. - logit_bias: Optional[Dict[int, int]] = None, # Optional: Modify the likelihood of specified tokens. - logprobs: Optional[int] = None, # Optional: Include the log probabilities on the most likely tokens. - max_tokens: Optional[int] = None, # Optional: The maximum number of tokens to generate in the completion. - n: Optional[int] = None, # Optional: How many completions to generate for each prompt. - presence_penalty: Optional[float] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. - stop: Optional[Union[str, List[str]]] = None, # Optional: Sequences where the API will stop generating further tokens. - stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. - suffix: Optional[str] = None, # Optional: The suffix that comes after a completion of inserted text. - temperature: Optional[float] = None, # Optional: Sampling temperature to use. - top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. - user: Optional[str] = None, # Optional: A unique identifier representing your end-user. - + prompt: Union[ + str, List[Union[str, List[Union[str, List[int]]]]] + ], # Required: The prompt(s) to generate completions for. + model: Optional[str] = None, # Optional: either `model` or `engine` can be set + best_of: Optional[ + int + ] = None, # Optional: Generates best_of completions server-side. + echo: Optional[ + bool + ] = None, # Optional: Echo back the prompt in addition to the completion. + frequency_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on their existing frequency. + logit_bias: Optional[ + Dict[int, int] + ] = None, # Optional: Modify the likelihood of specified tokens. 
+ logprobs: Optional[ + int + ] = None, # Optional: Include the log probabilities on the most likely tokens. + max_tokens: Optional[ + int + ] = None, # Optional: The maximum number of tokens to generate in the completion. + n: Optional[ + int + ] = None, # Optional: How many completions to generate for each prompt. + presence_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. + stop: Optional[ + Union[str, List[str]] + ] = None, # Optional: Sequences where the API will stop generating further tokens. + stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. + suffix: Optional[ + str + ] = None, # Optional: The suffix that comes after a completion of inserted text. + temperature: Optional[float] = None, # Optional: Sampling temperature to use. + top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. + user: Optional[ + str + ] = None, # Optional: A unique identifier representing your end-user. # set api_base, api_version, api_key api_base: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. # Optional liteLLM function params custom_llm_provider: Optional[str] = None, - *args, - **kwargs + *args, + **kwargs, ): global print_verbose import copy + """ Generate text completions using the OpenAI API. @@ -2120,8 +2461,8 @@ def text_completion( Example: Your example of how to use this function goes here. """ - if "engine" in kwargs: - if model==None: + if "engine" in kwargs: + if model == None: # only use engine when model not passed model = kwargs["engine"] kwargs.pop("engine") @@ -2168,7 +2509,7 @@ def text_completion( optional_params["custom_llm_provider"] = custom_llm_provider # get custom_llm_provider - _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore if custom_llm_provider == "huggingface": # if echo == True, for TGI llms we need to set top_n_tokens to 3 @@ -2180,17 +2521,19 @@ def text_completion( # processing prompt - users can pass raw tokens to OpenAI Completion() if type(prompt) == list: import concurrent.futures + tokenizer = tiktoken.encoding_for_model("text-davinci-003") ## if it's a 2d list - each element in the list is a text_completion() request if len(prompt) > 0 and type(prompt[0]) == list: - responses = [None for x in prompt] # init responses + responses = [None for x in prompt] # init responses + def process_prompt(i, individual_prompt): decoded_prompt = tokenizer.decode(individual_prompt) all_params = {**kwargs, **optional_params} response = text_completion( model=model, prompt=decoded_prompt, - num_retries=3,# ensure this does not fail for the batch + num_retries=3, # ensure this does not fail for the batch *args, **all_params, ) @@ -2200,22 +2543,28 @@ def text_completion( text_completion_response["created"] = response.get("created", None) text_completion_response["model"] = response.get("model", None) return response["choices"][0] + with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)] - for i, 
future in enumerate(concurrent.futures.as_completed(futures)): + futures = [ + executor.submit(process_prompt, i, individual_prompt) + for i, individual_prompt in enumerate(prompt) + ] + for i, future in enumerate( + concurrent.futures.as_completed(futures) + ): responses[i] = future.result() - text_completion_response.choices = responses + text_completion_response.choices = responses return text_completion_response # else: - # check if non default values passed in for best_of, echo, logprobs, suffix + # check if non default values passed in for best_of, echo, logprobs, suffix # these are the params supported by Completion() but not ChatCompletion - + # default case, non OpenAI requests go through here messages = [{"role": "system", "content": prompt}] kwargs.pop("prompt", None) response = completion( - model = model, + model=model, messages=messages, *args, **kwargs, @@ -2224,7 +2573,7 @@ def text_completion( if stream == True or kwargs.get("stream", False) == True: response = TextCompletionStreamWrapper(completion_stream=response, model=model) return response - if kwargs.get("acompletion", False) == True: + if kwargs.get("acompletion", False) == True: return response transformed_logprobs = None # only supported for TGI models @@ -2246,22 +2595,21 @@ def text_completion( text_completion_response["usage"] = response.get("usage", None) return text_completion_response + ##### Moderation ####################### -def moderation(input: str, api_key: Optional[str]=None): +def moderation(input: str, api_key: Optional[str] = None): # only supports open ai for now api_key = ( - api_key or - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") - ) + api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + ) openai.api_key = api_key - openai.api_type = "open_ai" # type: ignore + openai.api_type = "open_ai" # type: ignore openai.api_version = None openai.base_url = "https://api.openai.com/v1/" response = openai.moderations.create(input=input) return response + ##### Image Generation ####################### @client async def aimage_generation(*args, **kwargs): @@ -2277,10 +2625,10 @@ async def aimage_generation(*args, **kwargs): """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Image Generation ### + ### PASS ARGS TO Image Generation ### kwargs["aimg_generation"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(image_generation, *args, **kwargs) @@ -2288,117 +2636,217 @@ async def aimage_generation(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) - + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + # Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = await loop.run_in_executor(None, func_with_context) return response - except Exception as e: + except 
Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) - -@client -def image_generation(prompt: str, - model: Optional[str]=None, - n: Optional[int]=None, - quality: Optional[str]=None, - response_format: Optional[str]=None, - size: Optional[str]=None, - style: Optional[str]=None, - user: Optional[str]=None, - timeout=600, # default to 10 minutes - api_key: Optional[str]=None, - api_base: Optional[str]=None, - api_version: Optional[str] = None, - litellm_logging_obj=None, - custom_llm_provider=None, - **kwargs): - """ - Maps the https://api.openai.com/v1/images/generations endpoint. + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) - Currently supports just Azure + OpenAI. + +@client +def image_generation( + prompt: str, + model: Optional[str] = None, + n: Optional[int] = None, + quality: Optional[str] = None, + response_format: Optional[str] = None, + size: Optional[str] = None, + style: Optional[str] = None, + user: Optional[str] = None, + timeout=600, # default to 10 minutes + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + litellm_logging_obj=None, + custom_llm_provider=None, + **kwargs, +): + """ + Maps the https://api.openai.com/v1/images/generations endpoint. + + Currently supports just Azure + OpenAI. """ aimg_generation = kwargs.get("aimg_generation", False) litellm_call_id = kwargs.get("litellm_call_id", None) logger_fn = kwargs.get("logger_fn", None) - proxy_server_request = kwargs.get('proxy_server_request', None) + proxy_server_request = kwargs.get("proxy_server_request", None) model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) model_response = litellm.utils.ImageResponse() - if model is not None or custom_llm_provider is not None: - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - else: + if model is not None or custom_llm_provider is not None: + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + else: model = "dall-e-2" - custom_llm_provider = "openai" # default to dall-e-2 on openai - openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"] - litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] + custom_llm_provider = "openai" # default to dall-e-2 on openai + openai_params = [ + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "n", + "quality", + 
"size", + "style", + ] + litellm_params = [ + "metadata", + "aimg_generation", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + ] default_params = openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider - optional_params = get_optional_params_image_gen(n=n, - quality=quality, - response_format=response_format, - size=size, - style=style, - user=user, - custom_llm_provider=custom_llm_provider, - **non_default_params) + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + optional_params = get_optional_params_image_gen( + n=n, + quality=quality, + response_format=response_format, + size=size, + style=style, + user=user, + custom_llm_provider=custom_llm_provider, + **non_default_params, + ) logging = litellm_logging_obj - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}}) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": False, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "preset_cache_key": None, + "stream_response": {}, + }, + ) if custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_OPENAI_API_KEY") or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") ) - azure_ad_token = ( - optional_params.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) - model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation) + model_response = azure_chat_completions.image_generation( + model=model, + 
prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + api_version=api_version, + aimg_generation=aimg_generation, + ) elif custom_llm_provider == "openai": - model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation) + model_response = openai_chat_completions.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + aimg_generation=aimg_generation, + ) return model_response + ####### HELPER FUNCTIONS ################ ## Set verbose to true -> ```litellm.set_verbose = True``` def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass + def config_completion(**kwargs): if litellm.config_path != None: config_args = read_config_args(litellm.config_path) @@ -2409,7 +2857,8 @@ def config_completion(**kwargs): "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`" ) -def stream_chunk_builder(chunks: list, messages: Optional[list]=None): + +def stream_chunk_builder(chunks: list, messages: Optional[list] = None): id = chunks[0]["id"] object = chunks[0]["object"] created = chunks[0]["created"] @@ -2428,18 +2877,15 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): "choices": [ { "index": 0, - "message": { - "role": role, - "content": "" - }, + "message": {"role": role, "content": ""}, "finish_reason": finish_reason, } ], "usage": { "prompt_tokens": 0, # Modify as needed "completion_tokens": 0, # Modify as needed - "total_tokens": 0 # Modify as needed - } + "total_tokens": 0, # Modify as needed + }, } # Extract the "content" strings from the nested dictionaries within "choices" @@ -2447,7 +2893,10 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): combined_content = "" combined_arguments = "" - if "tool_calls" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None: + if ( + "tool_calls" in chunks[0]["choices"][0]["delta"] + and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None + ): argument_list = [] delta = chunks[0]["choices"][0]["delta"] message = response["choices"][0]["message"] @@ -2478,22 +2927,38 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): # Now, tool_calls is expected to be a dictionary arguments = tool_calls[0].function.arguments argument_list.append(arguments) - if tool_calls[0].function.name: + if tool_calls[0].function.name: name = tool_calls[0].function.name - if tool_calls[0].type: + if tool_calls[0].type: type = tool_calls[0].type - if curr_index != prev_index: # new tool call + if curr_index != prev_index: # new tool call combined_arguments = "".join(argument_list) - tool_calls_list.append({"id": prev_id, "index": prev_index, "function": {"arguments": combined_arguments, "name": name}, "type": type}) - argument_list = [] # reset + tool_calls_list.append( + { + "id": prev_id, + "index": prev_index, + "function": {"arguments": combined_arguments, "name": name}, + "type": type, + } + ) + argument_list = [] # reset prev_index = curr_index prev_id = 
curr_id combined_arguments = "".join(argument_list) - tool_calls_list.append({"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}) - response["choices"][0]["message"]["content"] = None + tool_calls_list.append( + { + "id": id, + "function": {"arguments": combined_arguments, "name": name}, + "type": type, + } + ) + response["choices"][0]["message"]["content"] = None response["choices"][0]["message"]["tool_calls"] = tool_calls_list - elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None: + elif ( + "function_call" in chunks[0]["choices"][0]["delta"] + and chunks[0]["choices"][0]["delta"]["function_call"] is not None + ): argument_list = [] delta = chunks[0]["choices"][0]["delta"] function_call = delta.get("function_call", "") @@ -2508,7 +2973,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): for choice in choices: delta = choice.get("delta", {}) function_call = delta.get("function_call", "") - + # Check if a function call is present if function_call: # Now, function_call is expected to be a dictionary @@ -2517,7 +2982,9 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): combined_arguments = "".join(argument_list) response["choices"][0]["message"]["content"] = None - response["choices"][0]["message"]["function_call"]["arguments"] = combined_arguments + response["choices"][0]["message"]["function_call"][ + "arguments" + ] = combined_arguments else: for chunk in chunks: choices = chunk["choices"] @@ -2525,7 +2992,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): delta = choice.get("delta", {}) content = delta.get("content", "") if content == None: - continue # openai v1.0.0 sets content = None for chunks + continue # openai v1.0.0 sets content = None for chunks content_list.append(content) # Combine the "content" strings into a single string || combine the 'function' strings into a single string @@ -2533,19 +3000,27 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): # Update the "content" field within the response dictionary response["choices"][0]["message"]["content"] = combined_content - + if len(combined_content) > 0: completion_output = combined_content - elif len(combined_arguments) > 0: + elif len(combined_arguments) > 0: completion_output = combined_arguments - else: + else: completion_output = "" # # Update usage information if needed try: - response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages) - except: # don't allow this failing to block a complete streaming response from being returned + response["usage"]["prompt_tokens"] = token_counter( + model=model, messages=messages + ) + except: # don't allow this failing to block a complete streaming response from being returned print_verbose(f"token_counter failed, assuming prompt tokens is 0") response["usage"]["prompt_tokens"] = 0 - response["usage"]["completion_tokens"] = token_counter(model=model, text=completion_output) - response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] - return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse()) + response["usage"]["completion_tokens"] = token_counter( + model=model, text=completion_output + ) + response["usage"]["total_tokens"] = ( + response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] + ) + return convert_to_model_response_object( + 
response_object=response, model_response_object=litellm.ModelResponse() + ) diff --git a/litellm/proxy/__init__.py b/litellm/proxy/__init__.py index b9742821a..b6e690fd5 100644 --- a/litellm/proxy/__init__.py +++ b/litellm/proxy/__init__.py @@ -1 +1 @@ -from . import * \ No newline at end of file +from . import * diff --git a/litellm/proxy/_experimental/post_call_rules.py b/litellm/proxy/_experimental/post_call_rules.py index 12caa5513..d5cbe31f1 100644 --- a/litellm/proxy/_experimental/post_call_rules.py +++ b/litellm/proxy/_experimental/post_call_rules.py @@ -1,4 +1,4 @@ -def my_custom_rule(input): # receives the model response +def my_custom_rule(input): # receives the model response # if len(input) < 5: # trigger fallback if the model response is too short - return False - return True \ No newline at end of file + return False + return True diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6c7af4ed3..5fe4fd44e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -3,13 +3,15 @@ from typing import Optional, List, Union, Dict, Literal from datetime import datetime import uuid, json + class LiteLLMBase(BaseModel): """ Implements default functions, all pydantic objects should have. """ + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() @@ -34,7 +36,7 @@ class ProxyChatCompletionRequest(LiteLLMBase): tools: Optional[List[str]] = None tool_choice: Optional[str] = None functions: Optional[List[str]] = None # soon to be deprecated - function_call: Optional[str] = None # soon to be deprecated + function_call: Optional[str] = None # soon to be deprecated # Optional LiteLLM params caching: Optional[bool] = None @@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase): request_timeout: Optional[int] = None class Config: - extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) + extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) + class ModelInfoDelete(LiteLLMBase): id: Optional[str] @@ -57,38 +60,37 @@ class ModelInfoDelete(LiteLLMBase): class ModelInfo(LiteLLMBase): id: Optional[str] - mode: Optional[Literal['embedding', 'chat', 'completion']] + mode: Optional[Literal["embedding", "chat", "completion"]] input_cost_per_token: Optional[float] = 0.0 output_cost_per_token: Optional[float] = 0.0 - max_tokens: Optional[int] = 2048 # assume 2048 if not set + max_tokens: Optional[int] = 2048 # assume 2048 if not set # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model # we look up the base model in model_prices_and_context_window.json - base_model: Optional[Literal - [ - 'gpt-4-1106-preview', - 'gpt-4-32k', - 'gpt-4', - 'gpt-3.5-turbo-16k', - 'gpt-3.5-turbo', - 'text-embedding-ada-002', - ] - ] + base_model: Optional[ + Literal[ + "gpt-4-1106-preview", + "gpt-4-32k", + "gpt-4", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo", + "text-embedding-ada-002", + ] + ] class Config: extra = Extra.allow # Allow extra fields protected_namespaces = () - @root_validator(pre=True) def set_model_info(cls, values): if values.get("id") is None: values.update({"id": str(uuid.uuid4())}) - if values.get("mode") is None: + if values.get("mode") is None: values.update({"mode": None}) - if values.get("input_cost_per_token") is None: + if values.get("input_cost_per_token") is None: values.update({"input_cost_per_token": None}) - if 
values.get("output_cost_per_token") is None: + if values.get("output_cost_per_token") is None: values.update({"output_cost_per_token": None}) if values.get("max_tokens") is None: values.update({"max_tokens": None}) @@ -97,21 +99,21 @@ class ModelInfo(LiteLLMBase): return values - class ModelParams(LiteLLMBase): model_name: str litellm_params: dict model_info: ModelInfo - + class Config: protected_namespaces = () - + @root_validator(pre=True) def set_model_info(cls, values): if values.get("model_info") is None: values.update({"model_info": ModelInfo()}) return values + class GenerateKeyRequest(LiteLLMBase): duration: Optional[str] = "1h" models: Optional[list] = [] @@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase): max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} + class UpdateKeyRequest(LiteLLMBase): key: str duration: Optional[str] = None @@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase): max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} -class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth + +class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth """ Return the row in the db """ + api_key: Optional[str] = None models: list = [] aliases: dict = {} @@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k duration: str = "1h" metadata: dict = {} + class GenerateKeyResponse(LiteLLMBase): key: str expires: Optional[datetime] user_id: str + class _DeleteKeyObject(LiteLLMBase): key: str + class DeleteKeyRequest(LiteLLMBase): keys: List[_DeleteKeyObject] + class NewUserRequest(GenerateKeyRequest): max_budget: Optional[float] = None + class NewUserResponse(GenerateKeyResponse): max_budget: Optional[float] = None + class ConfigGeneralSettings(LiteLLMBase): """ Documents all the fields supported by `general_settings` in config.yaml """ - completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls") - use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault") - master_key: Optional[str] = Field(None, description="require a key for all calls to proxy") - database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key") - otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.") - custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth") - max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key") - infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. 
OPENAI_API_KEY)") - background_health_checks: Optional[bool] = Field(None, description="run health checks in background") - health_check_interval: int = Field(300, description="background health check interval in seconds") - + + completion_model: Optional[str] = Field( + None, description="proxy level default model for all chat completion calls" + ) + use_azure_key_vault: Optional[bool] = Field( + None, description="load keys from azure key vault" + ) + master_key: Optional[str] = Field( + None, description="require a key for all calls to proxy" + ) + database_url: Optional[str] = Field( + None, + description="connect to a postgres db - needed for generating temporary keys + tracking spend / key", + ) + otel: Optional[bool] = Field( + None, + description="[BETA] OpenTelemetry support - this might change, use with caution.", + ) + custom_auth: Optional[str] = Field( + None, + description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth", + ) + max_parallel_requests: Optional[int] = Field( + None, description="maximum parallel requests for each api key" + ) + infer_model_from_keys: Optional[bool] = Field( + None, + description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)", + ) + background_health_checks: Optional[bool] = Field( + None, description="run health checks in background" + ) + health_check_interval: int = Field( + 300, description="background health check interval in seconds" + ) + class ConfigYAML(LiteLLMBase): """ Documents all the fields supported by the config.yaml """ - model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs") - litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache") + + model_list: Optional[List[ModelParams]] = Field( + None, + description="List of supported models on the server, with model-specific configs", + ) + litellm_settings: Optional[dict] = Field( + None, + description="litellm Module settings. 
See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache", + ) general_settings: Optional[ConfigGeneralSettings] = None + class Config: protected_namespaces = () diff --git a/litellm/proxy/custom_auth.py b/litellm/proxy/custom_auth.py index 933479708..416b66682 100644 --- a/litellm/proxy/custom_auth.py +++ b/litellm/proxy/custom_auth.py @@ -1,14 +1,16 @@ from litellm.proxy._types import UserAPIKeyAuth from fastapi import Request from dotenv import load_dotenv -import os +import os load_dotenv() -async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: - try: + + +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234" if api_key == modified_master_key: return UserAPIKeyAuth(api_key=api_key) raise Exception - except: - raise Exception \ No newline at end of file + except: + raise Exception diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py index dfcd55520..c3e1344d2 100644 --- a/litellm/proxy/custom_callbacks.py +++ b/litellm/proxy/custom_callbacks.py @@ -4,17 +4,19 @@ import sys, os, traceback sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path from litellm.integrations.custom_logger import CustomLogger import litellm import inspect + # This file includes the custom callbacks for LiteLLM Proxy # Once defined, these can be passed in proxy_config.yaml def print_verbose(print_statement): - if litellm.set_verbose: - print(print_statement) # noqa + if litellm.set_verbose: + print(print_statement) # noqa + class MyCustomHandler(CustomLogger): def __init__(self): @@ -23,36 +25,38 @@ class MyCustomHandler(CustomLogger): print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger") try: print_verbose(f"Logger Initialized with following methods:") - methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] - + methods = [ + method + for method in dir(self) + if inspect.ismethod(getattr(self, method)) + ] + # Pretty print_verbose the methods for method in methods: print_verbose(f" - {method}") print_verbose(f"{reset_color_code}") except: pass - - def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): print_verbose(f"Pre-API Call") - - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): print_verbose(f"Post-API Call") def log_stream_event(self, kwargs, response_obj, start_time, end_time): print_verbose(f"On Stream") - - def log_success_event(self, kwargs, response_obj, start_time, end_time): + + def log_success_event(self, kwargs, response_obj, start_time, end_time): print_verbose("On Success!") - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print_verbose(f"On Async Success!") response_cost = litellm.completion_cost(completion_response=response_obj) assert response_cost > 0.0 return - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: print_verbose(f"On Async Failure !") except Exception as e: @@ -64,4 +68,4 @@ proxy_handler_instance = MyCustomHandler() # need to set litellm.callbacks = [customHandler] # on the proxy -# litellm.success_callback = 
[async_on_succes_logger] \ No newline at end of file +# litellm.success_callback = [async_on_succes_logger] diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index 3c7ff599e..53dc2cf72 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -12,25 +12,16 @@ from litellm._logging import print_verbose logger = logging.getLogger(__name__) -ILLEGAL_DISPLAY_PARAMS = [ - "messages", - "api_key" -] +ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"] def _get_random_llm_message(): """ Get a random message from the LLM. """ - messages = [ - "Hey how's it going?", - "What's 1 + 1?" - ] + messages = ["Hey how's it going?", "What's 1 + 1?"] - - return [ - {"role": "user", "content": random.choice(messages)} - ] + return [{"role": "user", "content": random.choice(messages)}] def _clean_litellm_params(litellm_params: dict): @@ -44,34 +35,40 @@ async def _perform_health_check(model_list: list): """ Perform a health check for each model in the list. """ + async def _check_img_gen_model(model_params: dict): model_params.pop("messages", None) model_params["prompt"] = "test from litellm" try: await litellm.aimage_generation(**model_params) except Exception as e: - print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") + print_verbose( + f"Health check failed for model {model_params['model']}. Error: {e}" + ) return False return True - + async def _check_embedding_model(model_params: dict): model_params.pop("messages", None) model_params["input"] = ["test from litellm"] try: await litellm.aembedding(**model_params) except Exception as e: - print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") + print_verbose( + f"Health check failed for model {model_params['model']}. Error: {e}" + ) return False return True - async def _check_model(model_params: dict): try: await litellm.acompletion(**model_params) - except Exception as e: - print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") + except Exception as e: + print_verbose( + f"Health check failed for model {model_params['model']}. Error: {e}" + ) return False - + return True tasks = [] @@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list): return healthy_endpoints, unhealthy_endpoints - - -async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None): +async def perform_health_check( + model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None +): """ Perform a health check on the system. 
@@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl """ if not model_list: if cli_model: - model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}] + model_list = [ + {"model_name": cli_model, "litellm_params": {"model": cli_model}} + ] else: return [], [] @@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list) return healthy_endpoints, unhealthy_endpoints - - \ No newline at end of file diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py index e4dbdd5e7..fa24c9f0f 100644 --- a/litellm/proxy/hooks/max_budget_limiter.py +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -6,35 +6,42 @@ from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException import json, traceback -class MaxBudgetLimiter(CustomLogger): + +class MaxBudgetLimiter(CustomLogger): # Class variables or attributes def __init__(self): pass def print_verbose(self, print_statement): - if litellm.set_verbose is True: - print(print_statement) # noqa - - async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): - try: + if litellm.set_verbose is True: + print(print_statement) # noqa + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + try: self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" user_row = cache.get_cache(cache_key) - if user_row is None: # value not yet cached - return + if user_row is None: # value not yet cached + return max_budget = user_row["max_budget"] curr_spend = user_row["spend"] if max_budget is None: return - - if curr_spend is None: - return - + + if curr_spend is None: + return + # CHECK IF REQUEST ALLOWED if curr_spend >= max_budget: raise HTTPException(status_code=429, detail="Max budget limit reached.") - except HTTPException as e: + except HTTPException as e: raise e except Exception as e: - traceback.print_exc() \ No newline at end of file + traceback.print_exc() diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 98ee231b6..38247cbe0 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -5,18 +5,25 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException -class MaxParallelRequestsHandler(CustomLogger): + +class MaxParallelRequestsHandler(CustomLogger): user_api_key_cache = None + # Class variables or attributes def __init__(self): pass def print_verbose(self, print_statement): - if litellm.set_verbose is True: - print(print_statement) # noqa + if litellm.set_verbose is True: + print(print_statement) # noqa - - async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") api_key = user_api_key_dict.api_key max_parallel_requests = user_api_key_dict.max_parallel_requests @@ -26,8 +33,8 @@ class MaxParallelRequestsHandler(CustomLogger): if 
max_parallel_requests is None: return - - self.user_api_key_cache = cache # save the api key cache for updating the value + + self.user_api_key_cache = cache # save the api key cache for updating the value # CHECK IF REQUEST ALLOWED request_count_api_key = f"{api_key}_request_count" @@ -35,56 +42,67 @@ class MaxParallelRequestsHandler(CustomLogger): self.print_verbose(f"current: {current}") if current is None: cache.set_cache(request_count_api_key, 1) - elif int(current) < max_parallel_requests: + elif int(current) < max_parallel_requests: # Increase count for this token cache.set_cache(request_count_api_key, int(current) + 1) - else: - raise HTTPException(status_code=429, detail="Max parallel request limit reached.") - + else: + raise HTTPException( + status_code=429, detail="Max parallel request limit reached." + ) async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - try: + try: self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING") user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] if user_api_key is None: return - - if self.user_api_key_cache is None: + + if self.user_api_key_cache is None: return - + request_count_api_key = f"{user_api_key}_request_count" # check if it has collected an entire stream response - self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}") + self.print_verbose( + f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}" + ) if "complete_streaming_response" in kwargs or kwargs["stream"] != True: # Decrease count for this token - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 + current = ( + self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 + ) new_val = current - 1 self.print_verbose(f"updated_value in success call: {new_val}") self.user_api_key_cache.set_cache(request_count_api_key, new_val) - except Exception as e: - self.print_verbose(e) # noqa + except Exception as e: + self.print_verbose(e) # noqa - async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception): + async def async_log_failure_call( + self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception + ): try: self.print_verbose(f"Inside Max Parallel Request Failure Hook") api_key = user_api_key_dict.api_key if api_key is None: return - - if self.user_api_key_cache is None: + + if self.user_api_key_cache is None: return - + ## decrement call count if call failed - if (hasattr(original_exception, "status_code") - and original_exception.status_code == 429 - and "Max parallel request limit reached" in str(original_exception)): - pass # ignore failed calls due to max limit being reached - else: + if ( + hasattr(original_exception, "status_code") + and original_exception.status_code == 429 + and "Max parallel request limit reached" in str(original_exception) + ): + pass # ignore failed calls due to max limit being reached + else: request_count_api_key = f"{api_key}_request_count" # Decrease count for this token - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 + current = ( + self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 + ) new_val = current - 1 self.print_verbose(f"updated_value in failure call: {new_val}") self.user_api_key_cache.set_cache(request_count_api_key, new_val) except Exception as e: - self.print_verbose(f"An exception occurred - {str(e)}") # noqa \ No newline at end of file + self.print_verbose(f"An exception 
occurred - {str(e)}") # noqa diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 761319d15..39381c673 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -6,85 +6,202 @@ from datetime import datetime import importlib from dotenv import load_dotenv import operator + sys.path.append(os.getcwd()) config_filename = "litellm.secrets" # Using appdirs to determine user-specific config path config_dir = appdirs.user_config_dir("litellm") -user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) +user_config_path = os.getenv( + "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename) +) load_dotenv() from importlib import resources import shutil + telemetry = None + def run_ollama_serve(): try: - command = ['ollama', 'serve'] - - with open(os.devnull, 'w') as devnull: + command = ["ollama", "serve"] + + with open(os.devnull, "w") as devnull: process = subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - print(f""" + print( + f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """) # noqa + """ + ) # noqa + def clone_subfolder(repo_url, subfolder, destination): - # Clone the full repo - repo_name = repo_url.split('/')[-1] - repo_master = os.path.join(destination, "repo_master") - subprocess.run(['git', 'clone', repo_url, repo_master]) + # Clone the full repo + repo_name = repo_url.split("/")[-1] + repo_master = os.path.join(destination, "repo_master") + subprocess.run(["git", "clone", repo_url, repo_master]) - # Move into the subfolder - subfolder_path = os.path.join(repo_master, subfolder) + # Move into the subfolder + subfolder_path = os.path.join(repo_master, subfolder) - # Copy subfolder to destination - for file_name in os.listdir(subfolder_path): - source = os.path.join(subfolder_path, file_name) - if os.path.isfile(source): - shutil.copy(source, destination) - else: - dest_path = os.path.join(destination, file_name) - shutil.copytree(source, dest_path) + # Copy subfolder to destination + for file_name in os.listdir(subfolder_path): + source = os.path.join(subfolder_path, file_name) + if os.path.isfile(source): + shutil.copy(source, destination) + else: + dest_path = os.path.join(destination, file_name) + shutil.copytree(source, dest_path) + + # Remove cloned repo folder + subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")]) + feature_telemetry(feature="create-proxy") - # Remove cloned repo folder - subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")]) - feature_telemetry(feature="create-proxy") def is_port_in_use(port): import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 + @click.command() -@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.') -@click.option('--port', default=8000, help='Port to bind the server to.') -@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up') -@click.option('--api_base', default=None, help='API base URL.') -@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.') -@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects') -@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. 
"huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")') -@click.option('--add_key', default=None, help='The model name to pass to litellm expects') -@click.option('--headers', default=None, help='headers for the API call') -@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config') -@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input') -@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints') -@click.option('--temperature', default=None, type=float, help='Set temperature for the model') -@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model') -@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls') -@click.option('--drop_params', is_flag=True, help='Drop any unmapped params') -@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt') -@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`') -@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') -@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`') -@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version') -@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.') -@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml') -@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to') -@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response') -@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with') -@click.option('--local', is_flag=True, default=False, help='for local debugging') -def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version): +@click.option("--host", default="0.0.0.0", help="Host for the server to listen on.") +@click.option("--port", default=8000, help="Port to bind the server to.") +@click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up") +@click.option("--api_base", default=None, help="API base URL.") +@click.option( + "--api_version", + default="2023-07-01-preview", + help="For azure - pass in the api version.", +) +@click.option( + "--model", "-m", default=None, help="The model name to pass to litellm expects" +) +@click.option( + "--alias", + default=None, + help='The alias for the model - use this to give a litellm model name (e.g. 
"huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")', +) +@click.option( + "--add_key", default=None, help="The model name to pass to litellm expects" +) +@click.option("--headers", default=None, help="headers for the API call") +@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config") +@click.option( + "--debug", default=False, is_flag=True, type=bool, help="To debug the input" +) +@click.option( + "--use_queue", + default=False, + is_flag=True, + type=bool, + help="To use celery workers for async endpoints", +) +@click.option( + "--temperature", default=None, type=float, help="Set temperature for the model" +) +@click.option( + "--max_tokens", default=None, type=int, help="Set max tokens for the model" +) +@click.option( + "--request_timeout", + default=600, + type=int, + help="Set timeout in seconds for completion calls", +) +@click.option("--drop_params", is_flag=True, help="Drop any unmapped params") +@click.option( + "--add_function_to_prompt", + is_flag=True, + help="If function passed but unsupported, pass it as prompt", +) +@click.option( + "--config", + "-c", + default=None, + help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`", +) +@click.option( + "--max_budget", + default=None, + type=float, + help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`", +) +@click.option( + "--telemetry", + default=True, + type=bool, + help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`", +) +@click.option( + "--version", + "-v", + default=False, + is_flag=True, + type=bool, + help="Print LiteLLM version", +) +@click.option( + "--logs", + flag_value=False, + type=int, + help='Gets the "n" most recent logs. 
By default gets most recent log.', +) +@click.option( + "--health", + flag_value=True, + help="Make a chat/completions request to all llms in config.yaml", +) +@click.option( + "--test", + flag_value=True, + help="proxy chat completions url to make a test request to", +) +@click.option( + "--test_async", + default=False, + is_flag=True, + help="Calls async endpoints /queue/requests and /queue/response", +) +@click.option( + "--num_requests", + default=10, + type=int, + help="Number of requests to hit async endpoint with", +) +@click.option("--local", is_flag=True, default=False, help="for local debugging") +def run_server( + host, + port, + api_base, + api_version, + model, + alias, + add_key, + headers, + save, + debug, + temperature, + max_tokens, + request_timeout, + drop_params, + add_function_to_prompt, + config, + max_budget, + telemetry, + logs, + test, + local, + num_workers, + test_async, + num_requests, + use_queue, + health, + version, +): global feature_telemetry args = locals() if local: @@ -92,51 +209,60 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers else: try: from .proxy_server import app, save_worker_config, usage_telemetry - except ImportError as e: + except ImportError as e: from proxy_server import app, save_worker_config, usage_telemetry feature_telemetry = usage_telemetry if logs is not None: - if logs == 0: # default to 1 + if logs == 0: # default to 1 logs = 1 try: - with open('api_log.json') as f: + with open("api_log.json") as f: data = json.load(f) - # convert keys to datetime objects - log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()} + # convert keys to datetime objects + log_times = { + datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items() + } - # sort by timestamp - sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True) + # sort by timestamp + sorted_times = sorted( + log_times.items(), key=operator.itemgetter(0), reverse=True + ) # get n recent logs - recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]} + recent_logs = { + k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs] + } - print(json.dumps(recent_logs, indent=4)) # noqa + print(json.dumps(recent_logs, indent=4)) # noqa except: raise Exception("LiteLLM: No logs saved!") return if version == True: pkg_version = importlib.metadata.version("litellm") - click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n') + click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n") return - if model and "ollama" in model and api_base is None: + if model and "ollama" in model and api_base is None: run_ollama_serve() - if test_async is True: + if test_async is True: import requests, concurrent, time + api_base = f"http://{host}:{port}" - def _make_openai_completion(): + def _make_openai_completion(): data = { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Write a short poem about the moon"}] + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Write a short poem about the moon"} + ], } response = requests.post("http://0.0.0.0:8000/queue/request", json=data) response = response.json() - while True: - try: + while True: + try: url = response["url"] polling_url = f"{api_base}{url}" polling_response = requests.get(polling_url) @@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers if status == "finished": llm_response = polling_response["result"] break - print(f"POLLING 
JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa + print( + f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}" + ) # noqa time.sleep(0.5) except Exception as e: print("got exception in polling", e) @@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers futures = [] start_time = time.time() # Make concurrent calls - with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=concurrent_calls + ) as executor: for _ in range(concurrent_calls): futures.append(executor.submit(_make_openai_completion)) @@ -171,7 +301,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers failed_calls = 0 for future in futures: - if future.done(): + if future.done(): if future.result() is not None: successful_calls += 1 else: @@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers return if health != False: import requests + print("\nLiteLLM: Health Testing models in config") response = requests.get(url=f"http://{host}:{port}/health") print(json.dumps(response.json(), indent=4)) return if test != False: - click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy') + click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy") import openai - if test == True: # flag value set - api_base = f"http://{host}:{port}" - else: - api_base = test - client = openai.OpenAI( - api_key="My API Key", - base_url=api_base - ) - response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ - { - "role": "user", - "content": "this is a test request, write a short poem" - } - ], max_tokens=256) - click.echo(f'\nLiteLLM: response from proxy {response}') + if test == True: # flag value set + api_base = f"http://{host}:{port}" + else: + api_base = test + client = openai.OpenAI(api_key="My API Key", base_url=api_base) + + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "this is a test request, write a short poem", + } + ], + max_tokens=256, + ) + click.echo(f"\nLiteLLM: response from proxy {response}") print("\n Making streaming request to proxy") - response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ - { - "role": "user", - "content": "this is a test request, write a short poem" - } - ], - stream=True, + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "this is a test request, write a short poem", + } + ], + stream=True, ) for chunk in response: - click.echo(f'LiteLLM: streaming response from proxy {chunk}') + click.echo(f"LiteLLM: streaming response from proxy {chunk}") print("\n making completion request to proxy") - response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem') + response = client.completions.create( + model="gpt-3.5-turbo", prompt="this is a test request, write a short poem" + ) print(response) return else: if headers: headers = json.loads(headers) - save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, 
use_queue=use_queue) + save_worker_config( + model=model, + alias=alias, + api_base=api_base, + api_version=api_version, + debug=debug, + temperature=temperature, + max_tokens=max_tokens, + request_timeout=request_timeout, + max_budget=max_budget, + telemetry=telemetry, + drop_params=drop_params, + add_function_to_prompt=add_function_to_prompt, + headers=headers, + save=save, + config=config, + use_queue=use_queue, + ) try: import uvicorn except: - raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`") + raise ImportError( + "Uvicorn needs to be imported. Run - `pip install uvicorn`" + ) if port == 8000 and is_port_in_use(port): port = random.randint(1024, 49152) - uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers) + uvicorn.run( + "litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers + ) if __name__ == "__main__": diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0264ad2f9..f4a68eb97 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7,6 +7,7 @@ import secrets, subprocess import hashlib, uuid import warnings import importlib + messages: list = [] sys.path.insert( 0, os.path.abspath("../..") @@ -32,8 +33,8 @@ except ImportError: "fastapi", "appdirs", "backoff", - "pyyaml", - "orjson" + "pyyaml", + "orjson", ] ) import uvicorn @@ -88,19 +89,29 @@ def generate_feedback_box(): print() print() + import litellm from litellm.proxy.utils import ( - PrismaClient, - get_instance_fn, + PrismaClient, + get_instance_fn, ProxyLogging, - _cache_user_row + _cache_user_row, ) import pydantic from litellm.proxy._types import * from litellm.caching import DualCache from litellm.proxy.health_check import perform_health_check + litellm.suppress_debug_info = True -from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header +from fastapi import ( + FastAPI, + Request, + HTTPException, + status, + Depends, + BackgroundTasks, + Header, +) from fastapi.routing import APIRouter from fastapi.security import OAuth2PasswordBearer from fastapi.encoders import jsonable_encoder @@ -111,7 +122,11 @@ import json import logging from typing import Union -app = FastAPI(docs_url="/", title="LiteLLM API", description="Proxy Server to call 100+ LLMs in the OpenAI format") +app = FastAPI( + docs_url="/", + title="LiteLLM API", + description="Proxy Server to call 100+ LLMs in the OpenAI format", +) router = APIRouter() origins = ["*"] @@ -122,7 +137,9 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) -def log_input_output(request, response, custom_logger=None): + + +def log_input_output(request, response, custom_logger=None): if custom_logger is not None: custom_logger(request, response) global otel_logging @@ -137,35 +154,45 @@ def log_input_output(request, response, custom_logger=None): # Initialize OpenTelemetry components otlp_host = os.environ.get("OTEL_ENDPOINT", "localhost:4317") otlp_exporter = OTLPSpanExporter(endpoint=otlp_host, insecure=True) - resource = Resource.create({ - "service.name": "LiteLLM Proxy", - }) + resource = Resource.create( + { + "service.name": "LiteLLM Proxy", + } + ) trace.set_tracer_provider(TracerProvider(resource=resource)) tracer = trace.get_tracer(__name__) span_processor = SimpleSpanProcessor(otlp_exporter) trace.get_tracer_provider().add_span_processor(span_processor) with tracer.start_as_current_span("litellm-completion") as current_span: - input_event_attributes = {f"{key}": str(value) for key, 
value in dict(request).items() if value is not None} + input_event_attributes = { + f"{key}": str(value) + for key, value in dict(request).items() + if value is not None + } # Log the input event with attributes current_span.add_event( - name="LiteLLM: Request Input", - attributes=input_event_attributes + name="LiteLLM: Request Input", attributes=input_event_attributes ) - event_headers = {f"{key}": str(value) for key, value in dict(request.headers).items() if value is not None} + event_headers = { + f"{key}": str(value) + for key, value in dict(request.headers).items() + if value is not None + } current_span.add_event( - name="LiteLLM: Request Headers", - attributes=event_headers + name="LiteLLM: Request Headers", attributes=event_headers ) input_event_attributes.update(event_headers) - input_event_attributes.update({f"{key}": str(value) for key, value in dict(response).items()}) + input_event_attributes.update( + {f"{key}": str(value) for key, value in dict(response).items()} + ) current_span.add_event( - name="LiteLLM: Request Outpu", - attributes=input_event_attributes + name="LiteLLM: Request Outpu", attributes=input_event_attributes ) return True + from typing import Dict api_key_header = APIKeyHeader(name="Authorization", auto_error=False) @@ -179,7 +206,7 @@ user_telemetry = True user_config = None user_headers = None user_config_file_path = f"config_{time.time()}.yaml" -local_logging = True # writes logs to a local api_log.json file for debugging +local_logging = True # writes logs to a local api_log.json file for debugging experimental = False #### GLOBAL VARIABLES #### llm_router: Optional[litellm.Router] = None @@ -199,10 +226,12 @@ health_check_results = {} queue: List = [] ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) -### REDIS QUEUE ### +### REDIS QUEUE ### async_result = None -celery_app_conn = None -celery_fn = None # Redis Queue for handling requests +celery_app_conn = None +celery_fn = None # Redis Queue for handling requests + + #### HELPER FUNCTIONS #### def print_verbose(print_statement): try: @@ -212,6 +241,7 @@ def print_verbose(print_statement): except: pass + def usage_telemetry( feature: str, ): # helps us know if people are using this feature. 
Set `litellm --telemetry False` to your cli call to turn this off @@ -221,35 +251,40 @@ def usage_telemetry( target=litellm.utils.litellm_telemetry, args=(data,), daemon=True ).start() -def _get_bearer_token(api_key: str): - assert api_key.startswith("Bearer ") # ensure Bearer token passed in - api_key = api_key.replace("Bearer ", "") # extract the token + +def _get_bearer_token(api_key: str): + assert api_key.startswith("Bearer ") # ensure Bearer token passed in + api_key = api_key.replace("Bearer ", "") # extract the token return api_key -def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: + +def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: try: - return pydantic_obj.model_dump() # type: ignore + return pydantic_obj.model_dump() # type: ignore except: # if using pydantic v1 return pydantic_obj.dict() -async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth: + +async def user_api_key_auth( + request: Request, api_key: str = fastapi.Security(api_key_header) +) -> UserAPIKeyAuth: global master_key, prisma_client, llm_model_list, user_custom_auth - try: - if isinstance(api_key, str): + try: + if isinstance(api_key, str): api_key = _get_bearer_token(api_key=api_key) ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth: response = await user_custom_auth(request=request, api_key=api_key) return UserAPIKeyAuth.model_validate(response) - + if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth(api_key=api_key) else: return UserAPIKeyAuth() - - if api_key is None: # only require api key if master key is set + + if api_key is None: # only require api key if master key is set raise Exception(f"No api key passed in.") route: str = request.url.path @@ -258,34 +293,44 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap is_master_key_valid = secrets.compare_digest(api_key, master_key) if is_master_key_valid: return UserAPIKeyAuth(api_key=master_key) - - if (route.startswith("/key/") or route.startswith("/user/")) and not is_master_key_valid: - raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users") - if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error + if ( + route.startswith("/key/") or route.startswith("/user/") + ) and not is_master_key_valid: + raise Exception( + f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users" + ) + + if ( + prisma_client is None + ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error raise Exception("No connected db.") - + ## check for cache hit (In-Memory Cache) valid_token = user_api_key_cache.get_cache(key=api_key) print(f"valid_token from cache: {valid_token}") - if valid_token is None: - ## check db + if valid_token is None: + ## check db print(f"api key: {api_key}") - valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)) + valid_token = await prisma_client.get_data( + token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc) + ) print(f"valid token from prisma: {valid_token}") user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) - elif valid_token is not None: + elif valid_token is not None: print(f"API Key Cache Hit!") if valid_token: litellm.model_alias_map = valid_token.aliases config = 
valid_token.config if config != {}: model_list = config.get("model_list", []) - llm_model_list = model_list + llm_model_list = model_list print("\n new llm router model list", llm_model_list) - if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called + if ( + len(valid_token.models) == 0 + ): # assume an empty model list means all models are allowed to be called pass - else: + else: try: data = await request.json() except json.JSONDecodeError: @@ -303,34 +348,46 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap This makes the user row data accessible to pre-api call hooks. """ - asyncio.create_task(_cache_user_row(user_id=valid_token.user_id, cache=user_api_key_cache, db=prisma_client)) + asyncio.create_task( + _cache_user_row( + user_id=valid_token.user_id, + cache=user_api_key_cache, + db=prisma_client, + ) + ) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) - else: + else: raise Exception(f"Invalid token") - except Exception as e: + except Exception as e: print(f"An exception occurred - {traceback.format_exc()}") - if isinstance(e, HTTPException): + if isinstance(e, HTTPException): raise e - else: + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key", ) -def prisma_setup(database_url: Optional[str]): + +def prisma_setup(database_url: Optional[str]): global prisma_client, proxy_logging_obj, user_api_key_cache if database_url is not None: - try: - prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj) + try: + prisma_client = PrismaClient( + database_url=database_url, proxy_logging_obj=proxy_logging_obj + ) except Exception as e: - print("Error when initializing prisma, Ensure you run pip install prisma", e) + print( + "Error when initializing prisma, Ensure you run pip install prisma", e + ) -def load_from_azure_key_vault(use_azure_key_vault: bool = False): + +def load_from_azure_key_vault(use_azure_key_vault: bool = False): if use_azure_key_vault is False: return - - try: + + try: from azure.keyvault.secrets import SecretClient from azure.identity import ClientSecretCredential @@ -340,70 +397,103 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): # Set your Azure AD application/client ID, client secret, and tenant ID client_id = os.getenv("AZURE_CLIENT_ID", None) client_secret = os.getenv("AZURE_CLIENT_SECRET", None) - tenant_id = os.getenv("AZURE_TENANT_ID", None) + tenant_id = os.getenv("AZURE_TENANT_ID", None) - if KVUri is not None and client_id is not None and client_secret is not None and tenant_id is not None: + if ( + KVUri is not None + and client_id is not None + and client_secret is not None + and tenant_id is not None + ): # Initialize the ClientSecretCredential - credential = ClientSecretCredential(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id) + credential = ClientSecretCredential( + client_id=client_id, client_secret=client_secret, tenant_id=tenant_id + ) # Create the SecretClient using the credential client = SecretClient(vault_url=KVUri, credential=credential) - - litellm.secret_manager_client = client - else: - raise Exception(f"Missing KVUri or client_id or client_secret or tenant_id from environment") - except Exception as e: - print("Error when loading keys from Azure Key Vault. 
Ensure you run `pip install azure-identity azure-keyvault-secrets`") -def cost_tracking(): + litellm.secret_manager_client = client + else: + raise Exception( + f"Missing KVUri or client_id or client_secret or tenant_id from environment" + ) + except Exception as e: + print( + "Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`" + ) + + +def cost_tracking(): global prisma_client if prisma_client is not None: if isinstance(litellm.success_callback, list): print("setting litellm success callback to track cost") - if (track_cost_callback) not in litellm.success_callback: # type: ignore - litellm.success_callback.append(track_cost_callback) # type: ignore + if (track_cost_callback) not in litellm.success_callback: # type: ignore + litellm.success_callback.append(track_cost_callback) # type: ignore + async def track_cost_callback( - kwargs, # kwargs to completion - completion_response: litellm.ModelResponse, # response from completion - start_time = None, - end_time = None, # start/end time for completion + kwargs, # kwargs to completion + completion_response: litellm.ModelResponse, # response from completion + start_time=None, + end_time=None, # start/end time for completion ): global prisma_client try: # check if it has collected an entire stream response - print(f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}") + print( + f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" + ) if "complete_streaming_response" in kwargs: - # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost - completion_response=kwargs["complete_streaming_response"] - response_cost = litellm.completion_cost(completion_response=completion_response) + # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost + completion_response = kwargs["complete_streaming_response"] + response_cost = litellm.completion_cost( + completion_response=completion_response + ) print("streaming response_cost", response_cost) - user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) - user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) - if user_api_key and prisma_client: - await update_prisma_database(token=user_api_key, response_cost=response_cost) - elif kwargs["stream"] == False: # for non streaming responses - response_cost = litellm.completion_cost(completion_response=completion_response) - user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) - user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) - if user_api_key and prisma_client: - await update_prisma_database(token=user_api_key, response_cost=response_cost, user_id=user_id) + user_api_key = kwargs["litellm_params"]["metadata"].get( + "user_api_key", None + ) + user_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_user_id", None + ) + if user_api_key and prisma_client: + await update_prisma_database( + token=user_api_key, response_cost=response_cost + ) + elif kwargs["stream"] == False: # for non streaming responses + response_cost = litellm.completion_cost( + completion_response=completion_response + ) + user_api_key = kwargs["litellm_params"]["metadata"].get( + "user_api_key", None + ) + user_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_user_id", None + ) 
+ if user_api_key and prisma_client: + await update_prisma_database( + token=user_api_key, response_cost=response_cost, user_id=user_id + ) except Exception as e: print(f"error in tracking cost callback - {str(e)}") + async def update_prisma_database(token, response_cost, user_id=None): try: print(f"Enters prisma db call, token: {token}; user_id: {user_id}") + ### UPDATE USER SPEND ### - async def _update_user_db(): + async def _update_user_db(): if user_id is None: - return + return existing_spend_obj = await prisma_client.get_data(user_id=user_id) - if existing_spend_obj is None: + if existing_spend_obj is None: existing_spend = 0 else: existing_spend = existing_spend_obj.spend - + # Calculate the new cost by adding the existing cost and response_cost new_spend = existing_spend + response_cost @@ -412,11 +502,11 @@ async def update_prisma_database(token, response_cost, user_id=None): await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) ### UPDATE KEY SPEND ### - async def _update_key_db(): + async def _update_key_db(): # Fetch the existing cost for the given token existing_spend_obj = await prisma_client.get_data(token=token) print(f"existing spend: {existing_spend_obj}") - if existing_spend_obj is None: + if existing_spend_obj is None: existing_spend = 0 else: existing_spend = existing_spend_obj.spend @@ -426,7 +516,8 @@ async def update_prisma_database(token, response_cost, user_id=None): print(f"new cost: {new_spend}") # Update the cost column for the given token await prisma_client.update_data(token=token, data={"spend": new_spend}) - tasks = [] + + tasks = [] tasks.append(_update_user_db()) tasks.append(_update_key_db()) await asyncio.gather(*tasks) @@ -434,26 +525,32 @@ async def update_prisma_database(token, response_cost, user_id=None): print(f"Error updating Prisma database: {traceback.format_exc()}") pass + def run_ollama_serve(): try: - command = ['ollama', 'serve'] + command = ["ollama", "serve"] - with open(os.devnull, 'w') as devnull: + with open(os.devnull, "w") as devnull: process = subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - print(f""" + print( + f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """) + """ + ) + async def _run_background_health_check(): """ - Periodically run health checks in the background on the endpoints. + Periodically run health checks in the background on the endpoints. Update health_check_results, based on this. 
""" global health_check_results, llm_model_list, health_check_interval while True: - healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=llm_model_list) + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=llm_model_list + ) # Update the global variable with the health check results health_check_results["healthy_endpoints"] = healthy_endpoints @@ -463,47 +560,53 @@ async def _run_background_health_check(): await asyncio.sleep(health_check_interval) + def load_router_config(router: Optional[litellm.Router], config_file_path: str): global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue config = {} - try: + try: if os.path.exists(config_file_path): user_config_file_path = config_file_path - with open(config_file_path, 'r') as file: + with open(config_file_path, "r") as file: config = yaml.safe_load(file) else: - raise Exception(f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False") + raise Exception( + f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False" + ) except Exception as e: raise Exception(f"Exception while reading Config: {e}") - - ## PRINT YAML FOR CONFIRMING IT WORKS + + ## PRINT YAML FOR CONFIRMING IT WORKS printed_yaml = copy.deepcopy(config) printed_yaml.pop("environment_variables", None) - print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}") + print_verbose( + f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}" + ) ## ENVIRONMENT VARIABLES - environment_variables = config.get('environment_variables', None) - if environment_variables: - for key, value in environment_variables.items(): + environment_variables = config.get("environment_variables", None) + if environment_variables: + for key, value in environment_variables.items(): os.environ[key] = value ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) 
- litellm_settings = config.get('litellm_settings', None) - if litellm_settings is None: + litellm_settings = config.get("litellm_settings", None) + if litellm_settings is None: litellm_settings = {} - if litellm_settings: + if litellm_settings: # ANSI escape code for blue text blue_color_code = "\033[94m" reset_color_code = "\033[0m" - for key, value in litellm_settings.items(): + for key, value in litellm_settings.items(): if key == "cache": print(f"{blue_color_code}\nSetting Cache on Proxy") from litellm.caching import Cache + if isinstance(value, dict): cache_type = value.get("type", "redis") else: - cache_type = "redis" # default to using redis on cache + cache_type = "redis" # default to using redis on cache cache_responses = True cache_host = litellm.get_secret("REDIS_HOST", None) cache_port = litellm.get_secret("REDIS_PORT", None) @@ -513,7 +616,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): "type": cache_type, "host": cache_host, "port": cache_port, - "password": cache_password + "password": cache_password, } if "cache_params" in litellm_settings: @@ -525,44 +628,56 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") print(f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}") print(f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}") - print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}") + print( + f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}" + ) print() ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables - litellm.cache = Cache( - **cache_params + litellm.cache = Cache(**cache_params) + print( + f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}" ) - print(f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}") elif key == "callbacks": - litellm.callbacks = [get_instance_fn(value=value, config_file_path=config_file_path)] - print_verbose(f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}") + litellm.callbacks = [ + get_instance_fn(value=value, config_file_path=config_file_path) + ] + print_verbose( + f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" + ) elif key == "post_call_rules": - litellm.post_call_rules = [get_instance_fn(value=value, config_file_path=config_file_path)] + litellm.post_call_rules = [ + get_instance_fn(value=value, config_file_path=config_file_path) + ] print(f"litellm.post_call_rules: {litellm.post_call_rules}") elif key == "success_callback": litellm.success_callback = [] - + # intialize success callbacks for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function - if "." in callback: + if "." 
in callback: litellm.success_callback.append(get_instance_fn(value=callback)) # these are litellm callbacks - "langfuse", "sentry", "wandb" else: litellm.success_callback.append(callback) - print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}") + print_verbose( + f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}" + ) elif key == "failure_callback": litellm.failure_callback = [] - + # intialize success callbacks for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function - if "." in callback: + if "." in callback: litellm.failure_callback.append(get_instance_fn(value=callback)) # these are litellm callbacks - "langfuse", "sentry", "wandb" else: litellm.failure_callback.append(callback) - print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}") + print_verbose( + f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}" + ) elif key == "cache_params": # this is set in the cache branch # see usage here: https://docs.litellm.ai/docs/proxy/caching @@ -570,48 +685,51 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): else: setattr(litellm, key, value) - - ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging general_settings = config.get("general_settings", {}) - if general_settings is None: + if general_settings is None: general_settings = {} - if general_settings: + if general_settings: ### LOAD FROM AZURE KEY VAULT ### use_azure_key_vault = general_settings.get("use_azure_key_vault", False) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) ### CONNECT TO DATABASE ### database_url = general_settings.get("database_url", None) - if database_url and database_url.startswith("os.environ/"): + if database_url and database_url.startswith("os.environ/"): database_url = litellm.get_secret(database_url) prisma_setup(database_url=database_url) - ## COST TRACKING ## + ## COST TRACKING ## cost_tracking() ### START REDIS QUEUE ### use_queue = general_settings.get("use_queue", False) ### MASTER KEY ### master_key = general_settings.get("master_key", None) - if master_key and master_key.startswith("os.environ/"): + if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) #### OpenTelemetry Logging (OTEL) ######## - otel_logging = general_settings.get("otel", False) + otel_logging = general_settings.get("otel", False) if otel_logging == True: print("\nOpenTelemetry Logging Activated") ### CUSTOM API KEY AUTH ### custom_auth = general_settings.get("custom_auth", None) if custom_auth: - user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path) + user_custom_auth = get_instance_fn( + value=custom_auth, config_file_path=config_file_path + ) ### BACKGROUND HEALTH CHECKS ### # Enable background health checks - use_background_health_checks = general_settings.get("background_health_checks", False) + use_background_health_checks = general_settings.get( + "background_health_checks", False + ) health_check_interval = general_settings.get("health_check_interval", 300) router_params: dict = { - "num_retries": 3, - "cache_responses": litellm.cache != None # cache if user passed in cache values + "num_retries": 3, + "cache_responses": litellm.cache + != None, # cache if 
user passed in cache values } ## MODEL LIST - model_list = config.get('model_list', None) + model_list = config.get("model_list", None) if model_list: router_params["model_list"] = model_list print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m") @@ -619,9 +737,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"\033[32m {model.get('model_name', '')}\033[0m") litellm_model_name = model["litellm_params"]["model"] litellm_model_api_base = model["litellm_params"].get("api_base", None) - if "ollama" in litellm_model_name and litellm_model_api_base is None: + if "ollama" in litellm_model_name and litellm_model_api_base is None: run_ollama_serve() - + ## ROUTER SETTINGS (e.g. routing_strategy, ...) router_settings = config.get("router_settings", None) if router_settings and isinstance(router_settings, dict): @@ -632,37 +750,39 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): "model_list", } - available_args = [ - x for x in arg_spec.args if x not in exclude_args - ] + available_args = [x for x in arg_spec.args if x not in exclude_args] - for k, v in router_settings.items(): - if k in available_args: + for k, v in router_settings.items(): + if k in available_args: router_params[k] = v - - router = litellm.Router(**router_params) # type:ignore + + router = litellm.Router(**router_params) # type:ignore return router, model_list, general_settings -async def generate_key_helper_fn(duration: Optional[str], - models: list, - aliases: dict, - config: dict, - spend: float, - max_budget: Optional[float]=None, - token: Optional[str]=None, - user_id: Optional[str]=None, - max_parallel_requests: Optional[int]=None, - metadata: Optional[dict] = {},): + +async def generate_key_helper_fn( + duration: Optional[str], + models: list, + aliases: dict, + config: dict, + spend: float, + max_budget: Optional[float] = None, + token: Optional[str] = None, + user_id: Optional[str] = None, + max_parallel_requests: Optional[int] = None, + metadata: Optional[dict] = {}, +): global prisma_client - if prisma_client is None: - raise Exception(f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys ") - + if prisma_client is None: + raise Exception( + f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " + ) + if token is None: token = f"sk-{secrets.token_urlsafe(16)}" - - def _duration_in_seconds(duration: str): + def _duration_in_seconds(duration: str): match = re.match(r"(\d+)([smhd]?)", duration) if not match: raise ValueError("Invalid duration format") @@ -680,13 +800,13 @@ async def generate_key_helper_fn(duration: Optional[str], return value * 86400 else: raise ValueError("Unsupported duration unit") - - if duration is None: # allow tokens that never expire + + if duration is None: # allow tokens that never expire expires = None - else: + else: duration_s = _duration_in_seconds(duration=duration) expires = datetime.utcnow() + timedelta(seconds=duration_s) - + aliases_json = json.dumps(aliases) config_json = json.dumps(config) metadata_json = json.dumps(metadata) @@ -694,42 +814,51 @@ async def generate_key_helper_fn(duration: Optional[str], try: # Create a new verification token (you may want to enhance this logic based on your needs) verification_token_data = { - "token": token, + "token": token, "expires": expires, "models": models, "aliases": aliases_json, "config": config_json, - "spend": spend, - "user_id": user_id, + "spend": spend, + 
"user_id": user_id, "max_parallel_requests": max_parallel_requests, "metadata": metadata_json, - "max_budget": max_budget + "max_budget": max_budget, } - new_verification_token = await prisma_client.insert_data(data=verification_token_data) + new_verification_token = await prisma_client.insert_data( + data=verification_token_data + ) except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - return {"token": token, "expires": new_verification_token.expires, "user_id": user_id, "max_budget": max_budget} - + return { + "token": token, + "expires": new_verification_token.expires, + "user_id": user_id, + "max_budget": max_budget, + } async def delete_verification_token(tokens: List): global prisma_client - try: - if prisma_client: + try: + if prisma_client: # Assuming 'db' is your Prisma Client instance deleted_tokens = await prisma_client.delete_data(tokens=tokens) - else: + else: raise Exception - except Exception as e: + except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) return deleted_tokens -def save_worker_config(**data): + +def save_worker_config(**data): import json + os.environ["WORKER_CONFIG"] = json.dumps(data) + def initialize( model=None, alias=None, @@ -746,20 +875,22 @@ def initialize( headers=None, save=False, use_queue=False, - config=None, + config=None, ): global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth generate_feedback_box() user_model = model user_debug = debug - if debug==True: # this needs to be first, so users can see Router init debugg + if debug == True: # this needs to be first, so users can see Router init debugg litellm.set_verbose = True dynamic_config = {"general": {}, user_model: {}} if config: - llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config) - else: + llm_router, llm_model_list, general_settings = load_router_config( + router=llm_router, config_file_path=config + ) + else: # reset auth if config not passed, needed for consecutive tests on proxy - master_key = None + master_key = None user_custom_auth = None if headers: # model-specific param user_headers = headers @@ -791,7 +922,7 @@ def initialize( if max_budget: # litellm-specific param litellm.max_budget = max_budget dynamic_config["general"]["max_budget"] = max_budget - if experimental: + if experimental: pass user_telemetry = telemetry usage_telemetry(feature="local_proxy_server") @@ -810,10 +941,16 @@ def initialize( \n """ print() - print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n") - print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n") + print( + f'\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n' + ) + print( + f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n" + ) print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n") print(f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:8000 \033[0m\n") + + # for streaming def data_generator(response): print_verbose("inside generator") @@ -824,17 +961,20 @@ def data_generator(response): 
except: yield f"data: {json.dumps(chunk)}\n\n" + async def async_data_generator(response, user_api_key_dict): print_verbose("inside generator") - try: + try: async for chunk in response: print_verbose(f"returned chunk: {chunk}") try: yield f"data: {json.dumps(chunk.dict())}\n\n" except Exception as e: yield f"data: {str(e)}\n\n" - except Exception as e: + except Exception as e: yield f"data: {str(e)}\n\n" + + def get_litellm_model_info(model: dict = {}): model_info = model.get("model_info", {}) model_to_lookup = model.get("litellm_params", {}).get("model", None) @@ -847,13 +987,14 @@ def get_litellm_model_info(model: dict = {}): # this should not block returning on /model/info # if litellm does not have info on the model it should return {} return {} - + + @router.on_event("startup") async def startup_event(): global prisma_client, master_key, use_background_health_checks import json - ### LOAD CONFIG ### + ### LOAD CONFIG ### worker_config = litellm.get_secret("WORKER_CONFIG") print_verbose(f"worker_config: {worker_config}") # check if it's a valid file path @@ -863,42 +1004,47 @@ async def startup_event(): # if not, assume it's a json string worker_config = json.loads(os.getenv("WORKER_CONFIG")) initialize(**worker_config) - - - proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + + proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made if use_background_health_checks: - asyncio.create_task(_run_background_health_check()) # start the background health check coroutine. + asyncio.create_task( + _run_background_health_check() + ) # start the background health check coroutine. 
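(Illustrative aside, not part of the patch.) The `startup_event` above reads its settings from the `WORKER_CONFIG` secret that `save_worker_config` wrote earlier; a minimal sketch of that handoff, using made-up flag values, looks roughly like this:

import json
import os

# Mirror of save_worker_config() from this file: CLI flags are serialized into an env var.
def save_worker_config(**data):
    os.environ["WORKER_CONFIG"] = json.dumps(data)

# Hypothetical flags, e.g. what `litellm --model gpt-3.5-turbo --debug` might store.
save_worker_config(model="gpt-3.5-turbo", debug=True, config=None)

# On startup the proxy reads the same value back and passes it to initialize(**worker_config).
worker_config = json.loads(os.environ["WORKER_CONFIG"])
assert worker_config["model"] == "gpt-3.5-turbo"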
print_verbose(f"prisma client - {prisma_client}") - if prisma_client: + if prisma_client: await prisma_client.connect() - - if prisma_client is not None and master_key is not None: + + if prisma_client is not None and master_key is not None: # add master key to db - await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key) + await generate_key_helper_fn( + duration=None, models=[], aliases={}, config={}, spend=0, token=master_key + ) #### API ENDPOINTS #### @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) -@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list +@router.get( + "/models", dependencies=[Depends(user_api_key_auth)] +) # if project requires model list def model_list(): - global llm_model_list, general_settings + global llm_model_list, general_settings all_models = [] if general_settings.get("infer_model_from_keys", False): all_models = litellm.utils.get_valid_models() - if llm_model_list: + if llm_model_list: all_models = list(set(all_models + [m["model_name"] for m in llm_model_list])) if user_model is not None: all_models += [user_model] print_verbose(f"all_models: {all_models}") - ### CHECK OLLAMA MODELS ### + ### CHECK OLLAMA MODELS ### try: response = requests.get("http://0.0.0.0:11434/api/tags") models = response.json()["models"] ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models] all_models.extend(ollama_models) - except Exception as e: + except Exception as e: pass return dict( data=[ @@ -913,25 +1059,39 @@ def model_list(): object="list", ) -@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) -@router.post("/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) -@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) -async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): + +@router.post( + "/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"] +) +@router.post( + "/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"] +) +@router.post( + "/engines/{model:path}/completions", + dependencies=[Depends(user_api_key_auth)], + tags=["completions"], +) +async def completion( + request: Request, + model: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + background_tasks: BackgroundTasks = BackgroundTasks(), +): global user_temperature, user_request_timeout, user_max_tokens, user_api_base - try: + try: body = await request.body() body_str = body.decode() try: data = ast.literal_eval(body_str) - except: + except: data = json.loads(body_str) - + data["user"] = data.get("user", user_api_key_dict.user_id) data["model"] = ( - general_settings.get("completion_model", None) # server default - or user_model # model name passed via cli args - or model # for azure deployments - or data["model"] # default passed in http request + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or model # for azure deployments + or data["model"] # default passed in http request ) if user_model: data["model"] = user_model @@ -940,41 +1100,71 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key 
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["headers"] = dict(request.headers) else: - data["metadata"] = {"user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id} + data["metadata"] = { + "user_api_key": user_api_key_dict.api_key, + "user_api_key_user_id": user_api_key_dict.user_id, + } data["metadata"]["headers"] = dict(request.headers) # override with user settings, these are params passed via cli - if user_temperature: + if user_temperature: data["temperature"] = user_temperature if user_request_timeout: data["request_timeout"] = user_request_timeout - if user_max_tokens: + if user_max_tokens: data["max_tokens"] = user_max_tokens - if user_api_base: + if user_api_base: data["api_base"] = user_api_base ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion") + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="completion" + ) ### ROUTE THE REQUEST ### - router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] - if llm_router is not None and data["model"] in router_model_names: # model in router model list - response = await llm_router.atext_completion(**data) - elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router - response = await llm_router.atext_completion(**data, specific_deployment = True) - elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list response = await llm_router.atext_completion(**data) - else: # router is not set + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.atext_completion( + **data, specific_deployment=True + ) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.atext_completion(**data) + else: # router is not set response = await litellm.atext_completion(**data) - + print(f"final response: {response}") - if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses - return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream') - - background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + if ( + "stream" in data and data["stream"] == True + ): # use generate_responses to stream responses + return StreamingResponse( + async_data_generator( + user_api_key_dict=user_api_key_dict, response=response + ), + media_type="text/event-stream", + ) + + background_tasks.add_task( + log_input_output, request, response + ) # background task for logging to OTEL return response - except Exception as e: + except Exception as e: print(f"EXCEPTION RAISED IN PROXY MAIN.PY") - print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by 
setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`") + print( + f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" + ) traceback.print_exc() error_traceback = traceback.format_exc() error_msg = f"{str(e)}\n\n{error_traceback}" @@ -982,23 +1172,38 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key status = e.status_code # type: ignore except: status = 500 - raise HTTPException( - status_code=status, - detail=error_msg - ) + raise HTTPException(status_code=status, detail=error_msg) -@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) -@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) -@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint -async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): + +@router.post( + "/v1/chat/completions", + dependencies=[Depends(user_api_key_auth)], + tags=["chat/completions"], +) +@router.post( + "/chat/completions", + dependencies=[Depends(user_api_key_auth)], + tags=["chat/completions"], +) +@router.post( + "/openai/deployments/{model:path}/chat/completions", + dependencies=[Depends(user_api_key_auth)], + tags=["chat/completions"], +) # azure compatible endpoint +async def chat_completion( + request: Request, + model: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + background_tasks: BackgroundTasks = BackgroundTasks(), +): global general_settings, user_debug, proxy_logging_obj - try: + try: data = {} body = await request.body() body_str = body.decode() try: data = ast.literal_eval(body_str) - except: + except: data = json.loads(body_str) # Include original request and headers in the data @@ -1006,15 +1211,15 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap "url": str(request.url), "method": request.method, "headers": dict(request.headers), - "body": copy.copy(data) # use copy instead of deepcopy + "body": copy.copy(data), # use copy instead of deepcopy } print_verbose(f"receiving data: {data}") data["model"] = ( - general_settings.get("completion_model", None) # server default - or user_model # model name passed via cli args - or model # for azure deployments - or data["model"] # default passed in http request + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or model # for azure deployments + or data["model"] # default passed in http request ) # users can pass in 'user' param to /chat/completions. 
Don't override it @@ -1031,43 +1236,74 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - + global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli - if user_temperature: + if user_temperature: data["temperature"] = user_temperature if user_request_timeout: data["request_timeout"] = user_request_timeout - if user_max_tokens: + if user_max_tokens: data["max_tokens"] = user_max_tokens - if user_api_base: + if user_api_base: data["api_base"] = user_api_base ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion") + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="completion" + ) ### ROUTE THE REQUEST ### - router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] - if llm_router is not None and data["model"] in router_model_names: # model in router model list - response = await llm_router.acompletion(**data) - elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router - response = await llm_router.acompletion(**data, specific_deployment = True) - elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list response = await llm_router.acompletion(**data) - else: # router is not set + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.acompletion(**data, specific_deployment=True) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.acompletion(**data) + else: # router is not set response = await litellm.acompletion(**data) - + print(f"final response: {response}") - if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses - return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream') - - background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + if ( + "stream" in data and data["stream"] == True + ): # use generate_responses to stream responses + return StreamingResponse( + async_data_generator( + user_api_key_dict=user_api_key_dict, response=response + ), + media_type="text/event-stream", + ) + + background_tasks.add_task( + log_input_output, request, response + ) # background task for logging to OTEL return response except Exception as e: - await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) - print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. 
`litellm --model gpt-3.5-turbo --debug`") - router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] - if llm_router is not None and data.get("model", "") in router_model_names: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e + ) + print( + f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" + ) + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and data.get("model", "") in router_model_names: print("Results from router") print("\nRouter stats") print("\nTotal Calls made") @@ -1079,47 +1315,59 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap print("\nFail Calls made") for key, value in llm_router.fail_calls.items(): print(f"{key}: {value}") - if user_debug: + if user_debug: traceback.print_exc() - + if isinstance(e, HTTPException): raise e else: error_traceback = traceback.format_exc() error_msg = f"{str(e)}\n\n{error_traceback}" try: - status = e.status_code # type: ignore + status = e.status_code # type: ignore except: status = 500 - raise HTTPException( - status_code=status, - detail=error_msg - ) + raise HTTPException(status_code=status, detail=error_msg) -@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"]) -@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"]) -async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): + +@router.post( + "/v1/embeddings", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["embeddings"], +) +@router.post( + "/embeddings", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["embeddings"], +) +async def embeddings( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + background_tasks: BackgroundTasks = BackgroundTasks(), +): global proxy_logging_obj - try: + try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() data = orjson.loads(body) - # Include original request and headers in the data + # Include original request and headers in the data data["proxy_server_request"] = { "url": str(request.url), "method": request.method, "headers": dict(request.headers), - "body": copy.copy(data) # use copy instead of deepcopy + "body": copy.copy(data), # use copy instead of deepcopy } if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id data["model"] = ( - general_settings.get("embedding_model", None) # server default - or user_model # model name passed via cli args - or data["model"] # default passed in http request + general_settings.get("embedding_model", None) # server default + or user_model # model name passed via cli args + or data["model"] # default passed in http request ) if user_model: data["model"] = user_model @@ -1132,38 +1380,67 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - router_model_names = [m["model_name"] for m in llm_model_list] if 
llm_model_list is not None else [] - if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if ( + "input" in data + and isinstance(data["input"], list) + and isinstance(data["input"][0], list) + and isinstance(data["input"][0][0], int) + ): # check if array of tokens passed in # check if non-openai/azure model called - e.g. for langchain integration - if llm_model_list is not None and data["model"] in router_model_names: - for m in llm_model_list: - if m["model_name"] == data["model"] and (m["litellm_params"]["model"] in litellm.open_ai_embedding_models - or m["litellm_params"]["model"].startswith("azure/")): + if llm_model_list is not None and data["model"] in router_model_names: + for m in llm_model_list: + if m["model_name"] == data["model"] and ( + m["litellm_params"]["model"] in litellm.open_ai_embedding_models + or m["litellm_params"]["model"].startswith("azure/") + ): pass - else: + else: # non-openai/azure embedding model called with token input input_list = [] - for i in data["input"]: - input_list.append(litellm.decode(model="gpt-3.5-turbo", tokens=i)) + for i in data["input"]: + input_list.append( + litellm.decode(model="gpt-3.5-turbo", tokens=i) + ) data["input"] = input_list break - + ### CALL HOOKS ### - modify incoming data / reject request before calling the model - data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings") + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings" + ) ## ROUTE TO CORRECT ENDPOINT ## - if llm_router is not None and data["model"] in router_model_names: # model in router model list + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list response = await llm_router.aembedding(**data) - elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router - response = await llm_router.aembedding(**data, specific_deployment = True) - elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias - response = await llm_router.aembedding(**data) # ensure this goes the llm_router, router will do the correct alias mapping + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.aembedding(**data, specific_deployment=True) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.aembedding( + **data + ) # ensure this goes the llm_router, router will do the correct alias mapping else: response = await litellm.aembedding(**data) - background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + background_tasks.add_task( + log_input_output, request, response + ) # background task for logging to OTEL return response except Exception as e: - await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) + await 
proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e + ) traceback.print_exc() if isinstance(e, HTTPException): raise e @@ -1171,39 +1448,50 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen error_traceback = traceback.format_exc() error_msg = f"{str(e)}\n\n{error_traceback}" try: - status = e.status_code # type: ignore + status = e.status_code # type: ignore except: status = 500 - raise HTTPException( - status_code=status, - detail=error_msg - ) + raise HTTPException(status_code=status, detail=error_msg) -@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"]) -@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"]) -async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): +@router.post( + "/v1/images/generations", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["image generation"], +) +@router.post( + "/images/generations", + dependencies=[Depends(user_api_key_auth)], + response_class=ORJSONResponse, + tags=["image generation"], +) +async def image_generation( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + background_tasks: BackgroundTasks = BackgroundTasks(), +): global proxy_logging_obj - try: + try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() data = orjson.loads(body) - # Include original request and headers in the data + # Include original request and headers in the data data["proxy_server_request"] = { "url": str(request.url), "method": request.method, "headers": dict(request.headers), - "body": copy.copy(data) # use copy instead of deepcopy + "body": copy.copy(data), # use copy instead of deepcopy } if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id data["model"] = ( - general_settings.get("image_generation_model", None) # server default - or user_model # model name passed via cli args - or data["model"] # default passed in http request + general_settings.get("image_generation_model", None) # server default + or user_model # model name passed via cli args + or data["model"] # default passed in http request ) if user_model: data["model"] = user_model @@ -1216,24 +1504,46 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] - + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + ### CALL HOOKS ### - modify incoming data / reject request before calling the model - data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings") + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings" + ) ## ROUTE TO CORRECT ENDPOINT ## - if llm_router is not None and data["model"] in router_model_names: # model in router model list + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router 
model list response = await llm_router.aimage_generation(**data) - elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router - response = await llm_router.aimage_generation(**data, specific_deployment = True) - elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias - response = await llm_router.aimage_generation(**data) # ensure this goes the llm_router, router will do the correct alias mapping + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.aimage_generation( + **data, specific_deployment=True + ) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.aimage_generation( + **data + ) # ensure this goes the llm_router, router will do the correct alias mapping else: response = await litellm.aimage_generation(**data) - background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + background_tasks.add_task( + log_input_output, request, response + ) # background task for logging to OTEL return response except Exception as e: - await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e + ) traceback.print_exc() if isinstance(e, HTTPException): raise e @@ -1241,19 +1551,28 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = error_traceback = traceback.format_exc() error_msg = f"{str(e)}\n\n{error_traceback}" try: - status = e.status_code # type: ignore + status = e.status_code # type: ignore except: status = 500 - raise HTTPException( - status_code=status, - detail=error_msg - ) -#### KEY MANAGEMENT #### + raise HTTPException(status_code=status, detail=error_msg) -@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse) -async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)): + +#### KEY MANAGEMENT #### + + +@router.post( + "/key/generate", + tags=["key management"], + dependencies=[Depends(user_api_key_auth)], + response_model=GenerateKeyResponse, +) +async def generate_key_fn( + request: Request, + data: GenerateKeyRequest, + Authorization: Optional[str] = Header(None), +): """ - Generate an API key based on the provided data. + Generate an API key based on the provided data. Docs: https://docs.litellm.ai/docs/proxy/virtual_keys @@ -1261,121 +1580,161 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. 
- https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - - config: Optional[dict] - any key-specific configs, overrides config in config.yaml + - config: Optional[dict] - any key-specific configs, overrides config in config.yaml - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. - - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } Returns: - - key: (str) The generated api key + - key: (str) The generated api key - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ - data_json = data.json() # type: ignore + data_json = data.json() # type: ignore response = await generate_key_helper_fn(**data_json) - return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) + return GenerateKeyResponse( + key=response["token"], expires=response["expires"], user_id=response["user_id"] + ) -@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) + +@router.post( + "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) async def update_key_fn(request: Request, data: UpdateKeyRequest): """ Update an existing key """ global prisma_client - try: + try: data_json: dict = data.json() key = data_json.pop("key") - # get the row from db - if prisma_client is None: + # get the row from db + if prisma_client is None: raise Exception("Not connected to DB!") - + non_default_values = {k: v for k, v in data_json.items() if v is not None} - response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key}) + response = await prisma_client.update_data( + token=key, data={**non_default_values, "token": key} + ) return {"key": key, **non_default_values} - # update based on remaining passed in values - except Exception as e: + # update based on remaining passed in values + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"error": str(e)}, ) -@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) -async def delete_key_fn(request: Request, data: DeleteKeyRequest): - try: + +@router.post( + "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +async def delete_key_fn(request: Request, data: DeleteKeyRequest): + try: keys = data.keys - + deleted_keys = await delete_verification_token(tokens=keys) assert len(keys) == deleted_keys return {"deleted_keys": keys} - except Exception as e: + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"error": str(e)}, ) -@router.get("/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) -async def info_key_fn(key: str = fastapi.Query(..., description="Key in the request parameters")): + +@router.get( + "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) 
+async def info_key_fn( + key: str = fastapi.Query(..., description="Key in the request parameters") +): global prisma_client - try: - if prisma_client is None: - raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys") + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) key_info = await prisma_client.get_data(token=key) return {"key": key, "info": key_info} - except Exception as e: + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"error": str(e)}, ) -#### USER MANAGEMENT #### -@router.post("/user/new", tags=["user management"], dependencies=[Depends(user_api_key_auth)], response_model=NewUserResponse) +#### USER MANAGEMENT #### + + +@router.post( + "/user/new", + tags=["user management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewUserResponse, +) async def new_user(data: NewUserRequest): """ - Use this to create a new user with a budget. + Use this to create a new user with a budget. Returns user id, budget + new key. Parameters: - - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. + - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. - max_budget: Optional[float] - Specify max budget for a given user. - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - - config: Optional[dict] - any key-specific configs, overrides config in config.yaml + - config: Optional[dict] - any key-specific configs, overrides config in config.yaml - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. - - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } Returns: - - key: (str) The generated api key + - key: (str) The generated api key - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. - max_budget: (float|None) Max budget for given user. 
""" - data_json = data.json() # type: ignore + data_json = data.json() # type: ignore response = await generate_key_helper_fn(**data_json) - return NewUserResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"], max_budget=response["max_budget"]) + return NewUserResponse( + key=response["token"], + expires=response["expires"], + user_id=response["user_id"], + max_budget=response["max_budget"], + ) - -@router.post("/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]) +@router.post( + "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] +) async def user_info(request: Request): """ - [TODO]: Use this to get user information. (user row + all user key info) + [TODO]: Use this to get user information. (user row + all user key info) """ pass -@router.post("/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]) + +@router.post( + "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)] +) async def user_update(request: Request): """ [TODO]: Use this to update user budget """ pass - -#### MODEL MANAGEMENT #### - + + +#### MODEL MANAGEMENT #### + + #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 -@router.post("/model/new", description="Allows adding new models to the model list in the config.yaml", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) +@router.post( + "/model/new", + description="Allows adding new models to the model list in the config.yaml", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], +) async def add_new_model(model_params: ModelParams): global llm_router, llm_model_list, general_settings, user_config_file_path try: @@ -1384,26 +1743,29 @@ async def add_new_model(model_params: ModelParams): if os.path.exists(f"{user_config_file_path}"): with open(f"{user_config_file_path}", "r") as config_file: config = yaml.safe_load(config_file) - else: - config = {"model_list": []} - + else: + config = {"model_list": []} + print_verbose(f"Loaded config: {config}") # Add the new model to the config model_info = model_params.model_info.json() model_info = {k: v for k, v in model_info.items() if v is not None} - config['model_list'].append({ - 'model_name': model_params.model_name, - 'litellm_params': model_params.litellm_params, - 'model_info': model_info - }) + config["model_list"].append( + { + "model_name": model_params.model_name, + "litellm_params": model_params.litellm_params, + "model_info": model_info, + } + ) # Save the updated config with open(f"{user_config_file_path}", "w") as config_file: yaml.dump(config, config_file, default_flow_style=False) - # update Router - llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=user_config_file_path) - + # update Router + llm_router, llm_model_list, general_settings = load_router_config( + router=llm_router, config_file_path=user_config_file_path + ) return {"message": "Model added successfully"} @@ -1411,14 +1773,20 @@ async def add_new_model(model_params: ModelParams): traceback.print_exc() raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") + #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. 
If you need a stable endpoint use /model/info -@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) +@router.get( + "/model/info", + description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], +) async def model_info_v1(request: Request): global llm_model_list, general_settings, user_config_file_path # Load existing config with open(f"{user_config_file_path}", "r") as config_file: config = yaml.safe_load(config_file) - all_models = config['model_list'] + all_models = config["model_list"] for model in all_models: # provided model_info in config.yaml model_info = model.get("model_info", {}) @@ -1426,30 +1794,33 @@ async def model_info_v1(request: Request): # read litellm model_prices_and_context_window.json to get the following: # input_cost_per_token, output_cost_per_token, max_tokens litellm_model_info = get_litellm_model_info(model=model) - for k, v in litellm_model_info.items(): - if k not in model_info: + for k, v in litellm_model_info.items(): + if k not in model_info: model_info[k] = v model["model_info"] = model_info # don't return the api key model["litellm_params"].pop("api_key", None) print_verbose(f"all_models: {all_models}") - return { - "data": all_models - } + return {"data": all_models} #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933 -@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) +@router.get( + "/v1/model/info", + description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], +) async def model_info(request: Request): global llm_model_list, general_settings, user_config_file_path # Load existing config with open(f"{user_config_file_path}", "r") as config_file: config = yaml.safe_load(config_file) - all_models = config['model_list'] + all_models = config["model_list"] for model in all_models: - # get the model cost map info + # get the model cost map info ## make an api call data = copy.deepcopy(model["litellm_params"]) data["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] @@ -1460,13 +1831,13 @@ async def model_info(request: Request): print(f"response model: {response_model}; response - {response}") litellm_model_info = litellm.get_model_info(response_model) model_info = model.get("model_info", {}) - for k, v in litellm_model_info.items(): - if k not in model_info: + for k, v in litellm_model_info.items(): + if k not in model_info: model_info[k] = v model["model_info"] = model_info # don't return the api key model["litellm_params"].pop("api_key", None) - + # all_models = list(set([m["model_name"] for m in llm_model_list])) print_verbose(f"all_models: {all_models}") return dict( @@ -1482,8 +1853,14 @@ async def model_info(request: Request): object="list", ) + #### [BETA] - This is a beta endpoint, format might change based on user feedback. 
- https://github.com/BerriAI/litellm/issues/964 -@router.post("/model/delete", description="Allows deleting models in the model list in the config.yaml", tags=["model management"], dependencies=[Depends(user_api_key_auth)]) +@router.post( + "/model/delete", + description="Allows deleting models in the model list in the config.yaml", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], +) async def delete_model(model_info: ModelInfoDelete): global llm_router, llm_model_list, general_settings, user_config_file_path try: @@ -1494,27 +1871,33 @@ async def delete_model(model_info: ModelInfoDelete): config = yaml.safe_load(config_file) # If model_list is not in the config, nothing can be deleted - if 'model_list' not in config: - raise HTTPException(status_code=404, detail="No model list available in the config.") + if "model_list" not in config: + raise HTTPException( + status_code=404, detail="No model list available in the config." + ) # Check if the model with the specified model_id exists model_to_delete = None - for model in config['model_list']: - if model.get('model_info', {}).get('id', None) == model_info.id: + for model in config["model_list"]: + if model.get("model_info", {}).get("id", None) == model_info.id: model_to_delete = model break # If the model was not found, return an error if model_to_delete is None: - raise HTTPException(status_code=404, detail="Model with given model_id not found.") + raise HTTPException( + status_code=404, detail="Model with given model_id not found." + ) # Remove model from the list and save the updated config - config['model_list'].remove(model_to_delete) + config["model_list"].remove(model_to_delete) with open(user_config_file_path, "w") as config_file: yaml.dump(config, config_file, default_flow_style=False) # Update Router - llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=user_config_file_path) + llm_router, llm_model_list, general_settings = load_router_config( + router=llm_router, config_file_path=user_config_file_path + ) return {"message": "Model deleted successfully"} @@ -1524,43 +1907,76 @@ async def delete_model(model_info: ModelInfoDelete): except Exception as e: raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") -#### EXPERIMENTAL QUEUING #### + +#### EXPERIMENTAL QUEUING #### async def _litellm_chat_completions_worker(data, user_api_key_dict): """ - worker to make litellm completions calls + worker to make litellm completions calls """ - while True: - try: + while True: + try: ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion") + data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, data=data, call_type="completion" + ) print(f"_litellm_chat_completions_worker started") ### ROUTE THE REQUEST ### - router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] - if llm_router is not None and data["model"] in router_model_names: # model in router model list - response = await llm_router.acompletion(**data) - elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router - response = await llm_router.acompletion(**data, specific_deployment = True) - elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in 
llm_router.model_group_alias: # model set in model_group_alias + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list response = await llm_router.acompletion(**data) - else: # router is not set + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + response = await llm_router.acompletion( + **data, specific_deployment=True + ) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + response = await llm_router.acompletion(**data) + else: # router is not set response = await litellm.acompletion(**data) - + print(f"final response: {response}") return response - except HTTPException as e: - print(f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}") - if e.status_code == 429 and "Max parallel request limit reached" in e.detail: + except HTTPException as e: + print( + f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}" + ) + if ( + e.status_code == 429 + and "Max parallel request limit reached" in e.detail + ): print(f"Max parallel request limit reached!") - timeout = litellm._calculate_retry_after(remaining_retries=3, max_retries=3, min_timeout=1) + timeout = litellm._calculate_retry_after( + remaining_retries=3, max_retries=3, min_timeout=1 + ) await asyncio.sleep(timeout) - else: - raise e + else: + raise e -@router.post("/queue/chat/completions", tags=["experimental"], dependencies=[Depends(user_api_key_auth)]) -async def async_queue_request(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): - global general_settings, user_debug, proxy_logging_obj +@router.post( + "/queue/chat/completions", + tags=["experimental"], + dependencies=[Depends(user_api_key_auth)], +) +async def async_queue_request( + request: Request, + model: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + background_tasks: BackgroundTasks = BackgroundTasks(), +): + global general_settings, user_debug, proxy_logging_obj """ v2 attempt at a background worker to handle queuing. 
@@ -1568,24 +1984,24 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use Now using a FastAPI background task + /chat/completions compatible endpoint """ - try: + try: data = {} - data = await request.json() # type: ignore + data = await request.json() # type: ignore # Include original request and headers in the data data["proxy_server_request"] = { "url": str(request.url), "method": request.method, "headers": dict(request.headers), - "body": copy.copy(data) # use copy instead of deepcopy + "body": copy.copy(data), # use copy instead of deepcopy } print_verbose(f"receiving data: {data}") data["model"] = ( - general_settings.get("completion_model", None) # server default - or user_model # model name passed via cli args - or model # for azure deployments - or data["model"] # default passed in http request + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or model # for azure deployments + or data["model"] # default passed in http request ) # users can pass in 'user' param to /chat/completions. Don't override it @@ -1602,54 +2018,71 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - + global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli - if user_temperature: + if user_temperature: data["temperature"] = user_temperature if user_request_timeout: data["request_timeout"] = user_request_timeout - if user_max_tokens: + if user_max_tokens: data["max_tokens"] = user_max_tokens - if user_api_base: + if user_api_base: data["api_base"] = user_api_base - response = await asyncio.wait_for(_litellm_chat_completions_worker(data=data, user_api_key_dict=user_api_key_dict), timeout=litellm.request_timeout) + response = await asyncio.wait_for( + _litellm_chat_completions_worker( + data=data, user_api_key_dict=user_api_key_dict + ), + timeout=litellm.request_timeout, + ) - if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses - return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream') - - background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + if ( + "stream" in data and data["stream"] == True + ): # use generate_responses to stream responses + return StreamingResponse( + async_data_generator( + user_api_key_dict=user_api_key_dict, response=response + ), + media_type="text/event-stream", + ) + + background_tasks.add_task( + log_input_output, request, response + ) # background task for logging to OTEL return response - except Exception as e: - await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e + ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"error": str(e)}, ) - - + + @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) async def retrieve_server_log(request: Request): filepath = os.path.expanduser("~/.ollama/logs/server.log") return FileResponse(filepath) -#### BASIC ENDPOINTS #### +#### BASIC ENDPOINTS #### + 
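(Illustrative aside, not part of the patch.) The experimental `/queue/chat/completions` route defined just above is /chat/completions-compatible, so a client can exercise it with a plain HTTP POST; the proxy address and virtual key below are placeholders for a locally running instance:

import requests

# Hypothetical local proxy address and virtual key -- substitute your own.
resp = requests.post(
    "http://0.0.0.0:8000/queue/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.status_code, resp.json())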
@router.get("/config/yaml", tags=["config.yaml"]) -async def config_yaml_endpoint(config_info: ConfigYAML): +async def config_yaml_endpoint(config_info: ConfigYAML): """ - This is a mock endpoint, to show what you can set in config.yaml details in the Swagger UI. + This is a mock endpoint, to show what you can set in config.yaml details in the Swagger UI. Parameters: The config.yaml object has the following attributes: - - **model_list**: *Optional[List[ModelParams]]* - A list of supported models on the server, along with model-specific configurations. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc). + - **model_list**: *Optional[List[ModelParams]]* - A list of supported models on the server, along with model-specific configurations. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc). - **litellm_settings**: *Optional[dict]*: Settings for the litellm module. You can specify multiple properties like "drop_params", "set_verbose", "api_base", "cache". - - - **general_settings**: *Optional[ConfigGeneralSettings]*: General settings for the server like "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others. + + - **general_settings**: *Optional[ConfigGeneralSettings]*: General settings for the server like "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others. Please, refer to each class's description for a better understanding of the specific attributes within them. @@ -1659,7 +2092,7 @@ async def config_yaml_endpoint(config_info: ConfigYAML): @router.get("/test", tags=["health"]) -async def test_endpoint(request: Request): +async def test_endpoint(request: Request): """ A test endpoint that pings the proxy server to check if it's healthy. @@ -1672,12 +2105,18 @@ async def test_endpoint(request: Request): # ping the proxy server to check if its healthy return {"route": request.url.path} + @router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) -async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")): +async def health_endpoint( + request: Request, + model: Optional[str] = fastapi.Query( + None, description="Specify the model name (optional)" + ), +): """ Check the health of all the endpoints in config.yaml - To run health checks in the background, add this to config.yaml: + To run health checks in the background, add this to config.yaml: ``` general_settings: # ... 
other settings @@ -1687,10 +2126,12 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query """ global health_check_results, use_background_health_checks, user_model - if llm_model_list is None: + if llm_model_list is None: # if no router set, check if user set a model using litellm --model ollama/llama2 if user_model is not None: - healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=[], cli_model=user_model) + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=[], cli_model=user_model + ) return { "healthy_endpoints": healthy_endpoints, "unhealthy_endpoints": unhealthy_endpoints, @@ -1701,11 +2142,13 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"error": "Model list not initialized"}, ) - + if use_background_health_checks: return health_check_results else: - healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model) + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + llm_model_list, model + ) return { "healthy_endpoints": healthy_endpoints, @@ -1714,10 +2157,12 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query "unhealthy_count": len(unhealthy_endpoints), } + @router.get("/") async def home(request: Request): return "LiteLLM: RUNNING" + @router.get("/routes") async def get_routes(): """ @@ -1742,13 +2187,14 @@ async def shutdown_event(): if prisma_client: print("Disconnecting from Prisma") await prisma_client.disconnect() - - ## RESET CUSTOM VARIABLES ## + + ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables() + def cleanup_router_config_variables(): global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval - + # Set all variables to None master_key = None user_config_file_path = None diff --git a/litellm/proxy/queue/celery_app.py b/litellm/proxy/queue/celery_app.py index b9006f13e..9f59b6edf 100644 --- a/litellm/proxy/queue/celery_app.py +++ b/litellm/proxy/queue/celery_app.py @@ -1,71 +1,77 @@ -from dotenv import load_dotenv -load_dotenv() -import json, subprocess -import psutil # Import the psutil library -import atexit -try: - ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both - from celery import Celery - import redis -except: - import sys +# from dotenv import load_dotenv - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - "redis", - "celery" - ] - ) +# load_dotenv() +# import json, subprocess +# import psutil # Import the psutil library +# import atexit -import time -import sys, os -sys.path.insert( - 0, os.path.abspath("../../..") -) # Adds the parent directory to the system path - for litellm local dev -import litellm +# try: +# ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both +# from celery import Celery +# import redis +# except: +# import sys -# Redis connection setup -pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5) -redis_client = redis.Redis(connection_pool=pool) +# subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"]) -# Celery setup -celery_app = Celery('tasks', 
broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}") -celery_app.conf.update( - broker_pool_limit = None, - broker_transport_options = {'connection_pool': pool}, - result_backend_transport_options = {'connection_pool': pool}, -) +# import time +# import sys, os + +# sys.path.insert( +# 0, os.path.abspath("../../..") +# ) # Adds the parent directory to the system path - for litellm local dev +# import litellm + +# # Redis connection setup +# pool = redis.ConnectionPool( +# host=os.getenv("REDIS_HOST"), +# port=os.getenv("REDIS_PORT"), +# password=os.getenv("REDIS_PASSWORD"), +# db=0, +# max_connections=5, +# ) +# redis_client = redis.Redis(connection_pool=pool) + +# # Celery setup +# celery_app = Celery( +# "tasks", +# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", +# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", +# ) +# celery_app.conf.update( +# broker_pool_limit=None, +# broker_transport_options={"connection_pool": pool}, +# result_backend_transport_options={"connection_pool": pool}, +# ) -# Celery task -@celery_app.task(name='process_job', max_retries=3) -def process_job(*args, **kwargs): - try: - llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore - response = llm_router.completion(*args, **kwargs) # type: ignore - if isinstance(response, litellm.ModelResponse): - response = response.model_dump_json() - return json.loads(response) - return str(response) - except Exception as e: - raise e +# # Celery task +# @celery_app.task(name="process_job", max_retries=3) +# def process_job(*args, **kwargs): +# try: +# llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore +# response = llm_router.completion(*args, **kwargs) # type: ignore +# if isinstance(response, litellm.ModelResponse): +# response = response.model_dump_json() +# return json.loads(response) +# return str(response) +# except Exception as e: +# raise e -# Ensure Celery workers are terminated when the script exits -def cleanup(): - try: - # Get a list of all running processes - for process in psutil.process_iter(attrs=['pid', 'name']): - # Check if the process is a Celery worker process - if process.info['name'] == 'celery': - print(f"Terminating Celery worker with PID {process.info['pid']}") - # Terminate the Celery worker process - psutil.Process(process.info['pid']).terminate() - except Exception as e: - print(f"Error during cleanup: {e}") -# Register the cleanup function to run when the script exits -atexit.register(cleanup) \ No newline at end of file +# # Ensure Celery workers are terminated when the script exits +# def cleanup(): +# try: +# # Get a list of all running processes +# for process in psutil.process_iter(attrs=["pid", "name"]): +# # Check if the process is a Celery worker process +# if process.info["name"] == "celery": +# print(f"Terminating Celery worker with PID {process.info['pid']}") +# # Terminate the Celery worker process +# psutil.Process(process.info["pid"]).terminate() +# except Exception as e: +# print(f"Error during cleanup: {e}") + + +# # Register the cleanup function to run when the script exits +# atexit.register(cleanup) diff --git a/litellm/proxy/queue/celery_worker.py b/litellm/proxy/queue/celery_worker.py index 
7206723f8..41b0af515 100644 --- a/litellm/proxy/queue/celery_worker.py +++ b/litellm/proxy/queue/celery_worker.py @@ -1,12 +1,15 @@ import os from multiprocessing import Process + def run_worker(cwd): os.chdir(cwd) - os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info") + os.system( + "celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info" + ) + def start_worker(cwd): cwd += "/queue" worker_process = Process(target=run_worker, args=(cwd,)) worker_process.start() - \ No newline at end of file diff --git a/litellm/proxy/queue/rq_worker.py b/litellm/proxy/queue/rq_worker.py index 6e8ce29ae..7f9b34aed 100644 --- a/litellm/proxy/queue/rq_worker.py +++ b/litellm/proxy/queue/rq_worker.py @@ -1,26 +1,34 @@ -import sys, os -from dotenv import load_dotenv -load_dotenv() -# Add the path to the local folder to sys.path -sys.path.insert( - 0, os.path.abspath("../../..") -) # Adds the parent directory to the system path - for litellm local dev +# import sys, os +# from dotenv import load_dotenv -def start_rq_worker(): - from rq import Worker, Queue, Connection - from redis import Redis - # Set up RQ connection - redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) - print(redis_conn.ping()) # Should print True if connected successfully - # Create a worker and add the queue - try: - queue = Queue(connection=redis_conn) - worker = Worker([queue], connection=redis_conn) - except Exception as e: - print(f"Error setting up worker: {e}") - exit() - - with Connection(redis_conn): - worker.work() +# load_dotenv() +# # Add the path to the local folder to sys.path +# sys.path.insert( +# 0, os.path.abspath("../../..") +# ) # Adds the parent directory to the system path - for litellm local dev -start_rq_worker() \ No newline at end of file + +# def start_rq_worker(): +# from rq import Worker, Queue, Connection +# from redis import Redis + +# # Set up RQ connection +# redis_conn = Redis( +# host=os.getenv("REDIS_HOST"), +# port=os.getenv("REDIS_PORT"), +# password=os.getenv("REDIS_PASSWORD"), +# ) +# print(redis_conn.ping()) # Should print True if connected successfully +# # Create a worker and add the queue +# try: +# queue = Queue(connection=redis_conn) +# worker = Worker([queue], connection=redis_conn) +# except Exception as e: +# print(f"Error setting up worker: {e}") +# exit() + +# with Connection(redis_conn): +# worker.work() + + +# start_rq_worker() diff --git a/litellm/proxy/tests/bursty_load_test_completion.py b/litellm/proxy/tests/bursty_load_test_completion.py index 11766c6ab..529f2ce9f 100644 --- a/litellm/proxy/tests/bursty_load_test_completion.py +++ b/litellm/proxy/tests/bursty_load_test_completion.py @@ -4,20 +4,18 @@ import uuid import traceback -litellm_client = AsyncOpenAI( - api_key="test", - base_url="http://0.0.0.0:8000" -) +litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000") async def litellm_completion(): # Your existing code for litellm_completion goes here try: - response = await litellm_client.chat.completions.create( + response = await litellm_client.chat.completions.create( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request + messages=[ + {"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180} + ], # this is about 4k tokens per request ) - print(response) return response except Exception as e: @@ -25,7 +23,6 @@ async def litellm_completion(): with 
open("error_log.txt", "a") as error_log: error_log.write(f"Error during completion: {str(e)}\n") pass - async def main(): @@ -45,6 +42,7 @@ async def main(): print(n, time.time() - start, len(successful_completions)) + if __name__ == "__main__": # Blank out contents of error_log.txt open("error_log.txt", "w").close() diff --git a/litellm/proxy/tests/load_test_completion.py b/litellm/proxy/tests/load_test_completion.py index 73b215020..53937440a 100644 --- a/litellm/proxy/tests/load_test_completion.py +++ b/litellm/proxy/tests/load_test_completion.py @@ -4,16 +4,13 @@ import uuid import traceback -litellm_client = AsyncOpenAI( - api_key="sk-1234", - base_url="http://0.0.0.0:8000" -) +litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000") async def litellm_completion(): # Your existing code for litellm_completion goes here try: - response = await litellm_client.chat.completions.create( + response = await litellm_client.chat.completions.create( model="gpt-3.5-turbo", messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], ) @@ -24,7 +21,6 @@ async def litellm_completion(): with open("error_log.txt", "a") as error_log: error_log.write(f"Error during completion: {str(e)}\n") pass - async def main(): @@ -44,6 +40,7 @@ async def main(): print(n, time.time() - start, len(successful_completions)) + if __name__ == "__main__": # Blank out contents of error_log.txt open("error_log.txt", "w").close() diff --git a/litellm/proxy/tests/load_test_embedding.py b/litellm/proxy/tests/load_test_embedding.py index 097dcf2c5..6771af188 100644 --- a/litellm/proxy/tests/load_test_embedding.py +++ b/litellm/proxy/tests/load_test_embedding.py @@ -1,4 +1,4 @@ -# test time it takes to make 100 concurrent embedding requests to OpenaI +# test time it takes to make 100 concurrent embedding requests to OpenaI import sys, os import traceback @@ -14,16 +14,16 @@ import pytest import litellm -litellm.set_verbose=False +litellm.set_verbose = False question = "embed this very long text" * 100 # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. -# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere -# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions +# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere +# show me a summary of requests made, success full calls, failed calls. 
For failed calls show me the exceptions import concurrent.futures import random @@ -35,7 +35,10 @@ def make_openai_completion(question): try: start_time = time.time() import openai - client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000", + + client = openai.OpenAI( + api_key=os.environ["OPENAI_API_KEY"] + ) # base_url="http://0.0.0.0:8000", response = client.embeddings.create( model="text-embedding-ada-002", input=[question], @@ -58,6 +61,7 @@ def make_openai_completion(question): # ) return None + start_time = time.time() # Number of concurrent calls (you can adjust this) concurrent_calls = 500 diff --git a/litellm/proxy/tests/load_test_embedding_100.py b/litellm/proxy/tests/load_test_embedding_100.py index d7d272405..38ae8990a 100644 --- a/litellm/proxy/tests/load_test_embedding_100.py +++ b/litellm/proxy/tests/load_test_embedding_100.py @@ -4,24 +4,20 @@ import uuid import traceback -litellm_client = AsyncOpenAI( - api_key="test", - base_url="http://0.0.0.0:8000" -) - +litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000") async def litellm_completion(): # Your existing code for litellm_completion goes here try: print("starting embedding calls") - response = await litellm_client.embeddings.create( + response = await litellm_client.embeddings.create( model="text-embedding-ada-002", - input = [ - "hello who are you" * 2000, + input=[ + "hello who are you" * 2000, "hello who are you tomorrow 1234" * 1000, - "hello who are you tomorrow 1234" * 1000 - ] + "hello who are you tomorrow 1234" * 1000, + ], ) print(response) return response @@ -31,7 +27,6 @@ async def litellm_completion(): with open("error_log.txt", "a") as error_log: error_log.write(f"Error during completion: {str(e)}\n") pass - async def main(): @@ -51,6 +46,7 @@ async def main(): print(n, time.time() - start, len(successful_completions)) + if __name__ == "__main__": # Blank out contents of error_log.txt open("error_log.txt", "w").close() diff --git a/litellm/proxy/tests/load_test_embedding_proxy.py b/litellm/proxy/tests/load_test_embedding_proxy.py index 45136fafc..32c90cb8e 100644 --- a/litellm/proxy/tests/load_test_embedding_proxy.py +++ b/litellm/proxy/tests/load_test_embedding_proxy.py @@ -1,4 +1,4 @@ -# test time it takes to make 100 concurrent embedding requests to OpenaI +# test time it takes to make 100 concurrent embedding requests to OpenaI import sys, os import traceback @@ -14,16 +14,16 @@ import pytest import litellm -litellm.set_verbose=False +litellm.set_verbose = False question = "embed this very long text" * 100 # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. -# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere -# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions +# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere +# show me a summary of requests made, success full calls, failed calls. 
For failed calls show me the exceptions import concurrent.futures import random @@ -35,7 +35,10 @@ def make_openai_completion(question): try: start_time = time.time() import openai - client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000", + + client = openai.OpenAI( + api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000" + ) # base_url="http://0.0.0.0:8000", response = client.embeddings.create( model="text-embedding-ada-002", input=[question], @@ -58,6 +61,7 @@ def make_openai_completion(question): # ) return None + start_time = time.time() # Number of concurrent calls (you can adjust this) concurrent_calls = 500 diff --git a/litellm/proxy/tests/load_test_q.py b/litellm/proxy/tests/load_test_q.py index 6b1da1ece..657c27301 100644 --- a/litellm/proxy/tests/load_test_q.py +++ b/litellm/proxy/tests/load_test_q.py @@ -2,6 +2,7 @@ import requests import time import os from dotenv import load_dotenv + load_dotenv() @@ -12,37 +13,35 @@ base_url = "https://api.litellm.ai" # Step 1 Add a config to the proxy, generate a temp key config = { - "model_list": [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - "api_key": os.environ['OPENAI_API_KEY'], - } - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.environ['AZURE_API_KEY'], - "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", - "api_version": "2023-07-01-preview" - } - } - ] + "model_list": [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.environ["OPENAI_API_KEY"], + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.environ["AZURE_API_KEY"], + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "api_version": "2023-07-01-preview", + }, + }, + ] } print("STARTING LOAD TEST Q") -print(os.environ['AZURE_API_KEY']) +print(os.environ["AZURE_API_KEY"]) response = requests.post( url=f"{base_url}/key/generate", json={ "config": config, - "duration": "30d" # default to 30d, set it to 30m if you want a temp key + "duration": "30d", # default to 30d, set it to 30m if you want a temp key }, - headers={ - "Authorization": "Bearer sk-hosted-litellm" - } + headers={"Authorization": "Bearer sk-hosted-litellm"}, ) print("\nresponse from generating key", response.text) @@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key) import concurrent.futures + def create_job_and_poll(request_num): print(f"Creating a job on the proxy for request {request_num}") job_response = requests.post( url=f"{base_url}/queue/request", json={ - 'model': 'gpt-3.5-turbo', - 'messages': [ - {'role': 'system', 'content': 'write a short poem'}, + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "write a short poem"}, ], }, - headers={ - "Authorization": f"Bearer {generated_key}" - } + headers={"Authorization": f"Bearer {generated_key}"}, ) print(job_response.status_code) print(job_response.text) @@ -84,12 +82,12 @@ def create_job_and_poll(request_num): try: print(f"\nPolling URL for request {request_num}", polling_url) polling_response = requests.get( - url=polling_url, - headers={ - "Authorization": f"Bearer {generated_key}" - } + url=polling_url, headers={"Authorization": f"Bearer {generated_key}"} + ) + print( + f"\nResponse from polling url for request {request_num}", + polling_response.text, ) - print(f"\nResponse from 
polling url for request {request_num}", polling_response.text) polling_response = polling_response.json() status = polling_response.get("status", None) if status == "finished": @@ -109,6 +107,7 @@ def create_job_and_poll(request_num): except Exception as e: print("got exception when polling", e) + # Number of requests num_requests = 100 @@ -118,4 +117,4 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor futures = [executor.submit(create_job_and_poll, i) for i in range(num_requests)] # Wait for all futures to complete - concurrent.futures.wait(futures) \ No newline at end of file + concurrent.futures.wait(futures) diff --git a/litellm/proxy/tests/test_async.py b/litellm/proxy/tests/test_async.py index fab10ee08..65d289853 100644 --- a/litellm/proxy/tests/test_async.py +++ b/litellm/proxy/tests/test_async.py @@ -1,4 +1,4 @@ -# # This tests the litelm proxy +# # This tests the litelm proxy # # it makes async Completion requests with streaming # import openai @@ -8,14 +8,14 @@ # async def test_async_completion(): # response = await ( -# model="gpt-3.5-turbo", +# model="gpt-3.5-turbo", # prompt='this is a test request, write a short poem', # ) # print(response) # print("test_streaming") # response = await openai.chat.completions.create( -# model="gpt-3.5-turbo", +# model="gpt-3.5-turbo", # prompt='this is a test request, write a short poem', # stream=True # ) @@ -26,4 +26,3 @@ # import asyncio # asyncio.run(test_async_completion()) - diff --git a/litellm/proxy/tests/test_langchain_request.py b/litellm/proxy/tests/test_langchain_request.py index 1841b4968..9776f4134 100644 --- a/litellm/proxy/tests/test_langchain_request.py +++ b/litellm/proxy/tests/test_langchain_request.py @@ -34,7 +34,3 @@ # response = claude_chat(messages) # print(response) - - - - diff --git a/litellm/proxy/tests/test_q.py b/litellm/proxy/tests/test_q.py index 34e713a5d..5878f21ad 100644 --- a/litellm/proxy/tests/test_q.py +++ b/litellm/proxy/tests/test_q.py @@ -2,6 +2,7 @@ import requests import time import os from dotenv import load_dotenv + load_dotenv() @@ -12,26 +13,24 @@ base_url = "https://api.litellm.ai" # Step 1 Add a config to the proxy, generate a temp key config = { - "model_list": [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - "api_key": os.environ['OPENAI_API_KEY'], - } - } - ] + "model_list": [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.environ["OPENAI_API_KEY"], + }, + } + ] } response = requests.post( url=f"{base_url}/key/generate", json={ "config": config, - "duration": "30d" # default to 30d, set it to 30m if you want a temp key + "duration": "30d", # default to 30d, set it to 30m if you want a temp key }, - headers={ - "Authorization": "Bearer sk-hosted-litellm" - } + headers={"Authorization": "Bearer sk-hosted-litellm"}, ) print("\nresponse from generating key", response.text) @@ -45,22 +44,23 @@ print("Creating a job on the proxy") job_response = requests.post( url=f"{base_url}/queue/request", json={ - 'model': 'gpt-3.5-turbo', - 'messages': [ - {'role': 'system', 'content': f'You are a helpful assistant. What is your name'}, + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "system", + "content": f"You are a helpful assistant. 
What is your name", + }, ], }, - headers={ - "Authorization": f"Bearer {generated_key}" - } + headers={"Authorization": f"Bearer {generated_key}"}, ) print(job_response.status_code) print(job_response.text) print("\nResponse from creating job", job_response.text) job_response = job_response.json() -job_id = job_response["id"] # type: ignore -polling_url = job_response["url"] # type: ignore -polling_url = f"{base_url}{polling_url}" +job_id = job_response["id"] # type: ignore +polling_url = job_response["url"] # type: ignore +polling_url = f"{base_url}{polling_url}" print("\nCreated Job, Polling Url", polling_url) # Step 3: Poll the request @@ -68,16 +68,13 @@ while True: try: print("\nPolling URL", polling_url) polling_response = requests.get( - url=polling_url, - headers={ - "Authorization": f"Bearer {generated_key}" - } + url=polling_url, headers={"Authorization": f"Bearer {generated_key}"} ) print("\nResponse from polling url", polling_response.text) polling_response = polling_response.json() - status = polling_response.get("status", None) # type: ignore + status = polling_response.get("status", None) # type: ignore if status == "finished": - llm_response = polling_response["result"] # type: ignore + llm_response = polling_response["result"] # type: ignore print("LLM Response") print(llm_response) break diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 94aa51921..589c36ee7 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -8,16 +8,19 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException, status + def print_verbose(print_statement): if litellm.set_verbose: - print(f"LiteLLM Proxy: {print_statement}") # noqa -### LOGGING ### -class ProxyLogging: + print(f"LiteLLM Proxy: {print_statement}") # noqa + + +### LOGGING ### +class ProxyLogging: """ - Logging/Custom Handlers for proxy. + Logging/Custom Handlers for proxy. 
Implemented mainly to: - - log successful/failed db read/writes + - log successful/failed db read/writes - support the max parallel request integration """ @@ -25,15 +28,15 @@ class ProxyLogging: ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - self.max_parallel_request_limiter = MaxParallelRequestsHandler() - self.max_budget_limiter = MaxBudgetLimiter() + self.max_parallel_request_limiter = MaxParallelRequestsHandler() + self.max_budget_limiter = MaxBudgetLimiter() pass def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_budget_limiter) - for callback in litellm.callbacks: + for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) if callback not in litellm.success_callback: @@ -44,7 +47,7 @@ class ProxyLogging: litellm._async_success_callback.append(callback) if callback not in litellm._async_failure_callback: litellm._async_failure_callback.append(callback) - + if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 @@ -57,31 +60,41 @@ class ProxyLogging: + litellm.failure_callback ) ) - litellm.utils.set_callbacks( - callback_list=callback_list - ) + litellm.utils.set_callbacks(callback_list=callback_list) - async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]): + async def pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + call_type: Literal["completion", "embeddings"], + ): """ Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. - Covers: + Covers: 1. /chat/completions - 2. /embeddings + 2. /embeddings """ try: - for callback in litellm.callbacks: - if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__): - response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type) - if response is not None: + for callback in litellm.callbacks: + if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars( + callback.__class__ + ): + response = await callback.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=self.call_details["user_api_key_cache"], + data=data, + call_type=call_type, + ) + if response is not None: data = response - print_verbose(f'final data being sent to {call_type} call: {data}') + print_verbose(f"final data being sent to {call_type} call: {data}") return data except Exception as e: raise e - - async def success_handler(self, *args, **kwargs): + + async def success_handler(self, *args, **kwargs): """ Log successful db read/writes """ @@ -93,26 +106,31 @@ class ProxyLogging: Currently only logs exceptions to sentry """ - if litellm.utils.capture_exception: + if litellm.utils.capture_exception: litellm.utils.capture_exception(error=original_exception) - async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): + async def post_call_failure_hook( + self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth + ): """ Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. - Covers: + Covers: 1. /chat/completions - 2. /embeddings + 2. 
/embeddings """ for callback in litellm.callbacks: - try: + try: if isinstance(callback, CustomLogger): - await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception) - except Exception as e: + await callback.async_post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=original_exception, + ) + except Exception as e: raise e return - + ### DB CONNECTOR ### # Define the retry decorator with backoff strategy @@ -121,9 +139,12 @@ def on_backoff(details): # The 'tries' key in the details dictionary contains the number of completed tries print_verbose(f"Backing off... this was attempt #{details['tries']}") + class PrismaClient: def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): - print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") + print_verbose( + "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" + ) ## init logging object self.proxy_logging_obj = proxy_logging_obj @@ -136,23 +157,24 @@ class PrismaClient: os.chdir(dname) try: - subprocess.run(['prisma', 'generate']) - subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss + subprocess.run(["prisma", "generate"]) + subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"] + ) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss finally: os.chdir(original_dir) # Now you can import the Prisma Client - from prisma import Client # type: ignore - self.db = Client() #Client to connect to Prisma db + from prisma import Client # type: ignore - + self.db = Client() # Client to connect to Prisma db def hash_token(self, token: str): # Hash the string using SHA-256 hashed_token = hashlib.sha256(token.encode()).hexdigest() - + return hashed_token - def jsonify_object(self, data: dict) -> dict: + def jsonify_object(self, data: dict) -> dict: db_data = copy.deepcopy(data) for k, v in db_data.items(): @@ -162,233 +184,258 @@ class PrismaClient: @backoff.on_exception( backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None): - try: + async def get_data( + self, + token: Optional[str] = None, + expires: Optional[Any] = None, + user_id: Optional[str] = None, + ): + try: response = None - if token is not None: + if token is not None: # check if plain text or hash hashed_token = token - if token.startswith("sk-"): + if token.startswith("sk-"): hashed_token = self.hash_token(token=token) response = await self.db.litellm_verificationtoken.find_unique( - where={ - "token": hashed_token - } - ) + where={"token": hashed_token} + ) if response: # Token exists, now check expiration. - if response.expires is not None and expires is not None: + if response.expires is not None and expires is not None: if response.expires >= expires: # Token exists and is not expired. return response else: # Token exists but is expired. 
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="expired user key", + ) return response else: # Token does not exist. - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key") - elif user_id is not None: - response = await self.db.litellm_usertable.find_unique( # type: ignore - where={ - "user_id": user_id, - } - ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid user key", + ) + elif user_id is not None: + response = await self.db.litellm_usertable.find_unique( # type: ignore + where={ + "user_id": user_id, + } + ) return response - except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) + except Exception as e: + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def insert_data(self, data: dict): """ - Add a key to the database. If it already exists, do nothing. + Add a key to the database. If it already exists, do nothing. """ - try: + try: token = data["token"] hashed_token = self.hash_token(token=token) db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token max_budget = db_data.pop("max_budget", None) - new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore + new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore where={ - 'token': hashed_token, + "token": hashed_token, }, data={ - "create": {**db_data}, #type: ignore - "update": {} # don't do anything if it already exists - } + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, ) new_user_row = await self.db.litellm_usertable.upsert( - where={ - 'user_id': data['user_id'] - }, + where={"user_id": data["user_id"]}, data={ - "create": {"user_id": data['user_id'], "max_budget": max_budget}, - "update": {} # don't do anything if it already exists - } + "create": {"user_id": data["user_id"], "max_budget": max_budget}, + "update": {}, # don't do anything if it already exists + }, ) return new_verification_token except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None): + async def update_data( + self, + token: Optional[str] = None, + data: dict = {}, + user_id: Optional[str] = None, + ): """ 
Update existing data """ - try: + try: db_data = self.jsonify_object(data=data) - if token is not None: + if token is not None: print_verbose(f"token: {token}") # check if plain text or hash - if token.startswith("sk-"): + if token.startswith("sk-"): token = self.hash_token(token=token) - db_data["token"] = token + db_data["token"] = token response = await self.db.litellm_verificationtoken.update( - where={ - "token": token # type: ignore - }, - data={**db_data} # type: ignore + where={"token": token}, # type: ignore + data={**db_data}, # type: ignore ) print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") return {"token": token, "data": db_data} - elif user_id is not None: + elif user_id is not None: """ If data['spend'] + data['user'], update the user table with spend info as well """ update_user_row = await self.db.litellm_usertable.update( - where={ - 'user_id': user_id # type: ignore - }, - data={**db_data} # type: ignore + where={"user_id": user_id}, # type: ignore + data={**db_data}, # type: ignore ) return {"user_id": user_id, "data": db_data} - except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) + except Exception as e: + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") raise e - # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def delete_data(self, tokens: List): """ Allow user to delete a key(s) """ - try: + try: hashed_tokens = [self.hash_token(token=token) for token in tokens] await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} - ) + where={"token": {"in": hashed_tokens}} + ) return {"deleted_keys": tokens} - except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) - raise e - - # Define a retrying strategy with exponential backoff - @backoff.on_exception( - backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for - on_backoff=on_backoff, # specifying the function to call on backoff - ) - async def connect(self): - try: - await self.db.connect() - except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) + except Exception as e: + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, - Exception, # base exception to catch for the backoff - max_tries=3, # maximum number of retries - max_time=10, # maximum total time to retry for + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def disconnect(self): + async def connect(self): + try: + await self.db.connect() + except Exception as e: + asyncio.create_task( + 
self.proxy_logging_obj.failure_handler(original_exception=e) + ) + raise e + + # Define a retrying strategy with exponential backoff + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def disconnect(self): try: await self.db.disconnect() - except Exception as e: - asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) + except Exception as e: + asyncio.create_task( + self.proxy_logging_obj.failure_handler(original_exception=e) + ) raise e + ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: try: print_verbose(f"value: {value}") # Split the path by dots to separate module from instance parts = value.split(".") - + # The module path is all but the last part, and the instance_name is the last part module_name = ".".join(parts[:-1]) instance_name = parts[-1] - + # If config_file_path is provided, use it to determine the module spec and load the module if config_file_path is not None: directory = os.path.dirname(config_file_path) - module_file_path = os.path.join(directory, *module_name.split('.')) - module_file_path += '.py' + module_file_path = os.path.join(directory, *module_name.split(".")) + module_file_path += ".py" spec = importlib.util.spec_from_file_location(module_name, module_file_path) if spec is None: - raise ImportError(f"Could not find a module specification for {module_file_path}") + raise ImportError( + f"Could not find a module specification for {module_file_path}" + ) module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) # type: ignore + spec.loader.exec_module(module) # type: ignore else: # Dynamically import the module module = importlib.import_module(module_name) - + # Get the instance from the module instance = getattr(module, instance_name) - + return instance except ImportError as e: # Re-raise the exception with a user-friendly message raise ImportError(f"Could not import {instance_name} from {module_name}") from e - except Exception as e: + except Exception as e: raise e + ### HELPER FUNCTIONS ### -async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): +async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): """ - Check if a user_id exists in cache, - if not retrieve it. + Check if a user_id exists in cache, + if not retrieve it. """ cache_key = f"{user_id}_user_api_key_user_id" response = cache.get_cache(key=cache_key) - if response is None: # Cache miss + if response is None: # Cache miss user_row = await db.get_data(user_id=user_id) cache_value = user_row.model_dump_json() - cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes - return \ No newline at end of file + cache.set_cache( + key=cache_key, value=cache_value, ttl=600 + ) # store for 10 minutes + return diff --git a/litellm/router.py b/litellm/router.py index c176ab60a..6e1b77748 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5,7 +5,7 @@ # | | # +-----------------------------------------------+ # -# Thank you ! We ❤️ you! - Krrish & Ishaan +# Thank you ! We ❤️ you! 
- Krrish & Ishaan import copy, httpx from datetime import datetime @@ -18,8 +18,13 @@ import inspect, concurrent from openai import AsyncOpenAI from collections import defaultdict from litellm.router_strategy.least_busy import LeastBusyLoggingHandler -from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport +from litellm.llms.custom_httpx.azure_dall_e_2 import ( + CustomHTTPTransport, + AsyncCustomHTTPTransport, +) import copy + + class Router: """ Example usage: @@ -27,7 +32,7 @@ class Router: from litellm import Router model_list = [ { - "model_name": "azure-gpt-3.5-turbo", # model alias + "model_name": "azure-gpt-3.5-turbo", # model alias "litellm_params": { # params for litellm completion/embedding call "model": "azure/", "api_key": , @@ -36,7 +41,7 @@ class Router: }, }, { - "model_name": "azure-gpt-3.5-turbo", # model alias + "model_name": "azure-gpt-3.5-turbo", # model alias "litellm_params": { # params for litellm completion/embedding call "model": "azure/", "api_key": , @@ -45,7 +50,7 @@ class Router: }, }, { - "model_name": "openai-gpt-3.5-turbo", # model alias + "model_name": "openai-gpt-3.5-turbo", # model alias "litellm_params": { # params for litellm completion/embedding call "model": "gpt-3.5-turbo", "api_key": , @@ -55,6 +60,7 @@ class Router: router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}]) ``` """ + model_names: List = [] cache_responses: Optional[bool] = False default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour @@ -62,51 +68,77 @@ class Router: tenacity = None leastbusy_logger: Optional[LeastBusyLoggingHandler] = None - def __init__(self, - model_list: Optional[list] = None, - ## CACHING ## - redis_url: Optional[str] = None, - redis_host: Optional[str] = None, - redis_port: Optional[int] = None, - redis_password: Optional[str] = None, - cache_responses: Optional[bool] = False, - cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) - caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups - ## RELIABILITY ## - num_retries: int = 0, - timeout: Optional[float] = None, - default_litellm_params = {}, # default params for Router.chat.completion.create - set_verbose: bool = False, - fallbacks: List = [], - allowed_fails: Optional[int] = None, - context_window_fallbacks: List = [], - model_group_alias: Optional[dict] = {}, - routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None: - - self.set_verbose = set_verbose - self.deployment_names: List = [] # names of models under litellm_params. ex. 
azure/chatgpt-v-2 + def __init__( + self, + model_list: Optional[list] = None, + ## CACHING ## + redis_url: Optional[str] = None, + redis_host: Optional[str] = None, + redis_port: Optional[int] = None, + redis_password: Optional[str] = None, + cache_responses: Optional[bool] = False, + cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) + caching_groups: Optional[ + List[tuple] + ] = None, # if you want to cache across model groups + ## RELIABILITY ## + num_retries: int = 0, + timeout: Optional[float] = None, + default_litellm_params={}, # default params for Router.chat.completion.create + set_verbose: bool = False, + fallbacks: List = [], + allowed_fails: Optional[int] = None, + context_window_fallbacks: List = [], + model_group_alias: Optional[dict] = {}, + routing_strategy: Literal[ + "simple-shuffle", + "least-busy", + "usage-based-routing", + "latency-based-routing", + ] = "simple-shuffle", + ) -> None: + self.set_verbose = set_verbose + self.deployment_names: List = ( + [] + ) # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_latency_map = {} if model_list: model_list = copy.deepcopy(model_list) self.set_model_list(model_list) self.healthy_deployments: List = self.model_list - for m in model_list: + for m in model_list: self.deployment_latency_map[m["litellm_params"]["model"]] = 0 - + self.allowed_fails = allowed_fails or litellm.allowed_fails - self.failed_calls = InMemoryCache() # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown + self.failed_calls = ( + InMemoryCache() + ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown self.num_retries = num_retries or litellm.num_retries or 0 self.timeout = timeout or litellm.request_timeout self.routing_strategy = routing_strategy self.fallbacks = fallbacks or litellm.fallbacks - self.context_window_fallbacks = context_window_fallbacks or litellm.context_window_fallbacks - self.model_exception_map: dict = {} # dict to store model: list exceptions. self.exceptions = {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]} - self.total_calls: defaultdict = defaultdict(int) # dict to store total calls made to each model - self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model - self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model - self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call) - self.model_group_alias: dict = model_group_alias or {} # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group - + self.context_window_fallbacks = ( + context_window_fallbacks or litellm.context_window_fallbacks + ) + self.model_exception_map: dict = ( + {} + ) # dict to store model: list exceptions. 
self.exceptions = {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]} + self.total_calls: defaultdict = defaultdict( + int + ) # dict to store total calls made to each model + self.fail_calls: defaultdict = defaultdict( + int + ) # dict to store fail_calls made to each model + self.success_calls: defaultdict = defaultdict( + int + ) # dict to store success_calls made to each model + self.previous_models: List = ( + [] + ) # list to store failed calls (passed in as metadata to next call) + self.model_group_alias: dict = ( + model_group_alias or {} + ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group + # make Router.chat.completions.create compatible for openai.chat.completions.create self.chat = litellm.Chat(params=default_litellm_params) @@ -114,26 +146,32 @@ class Router: self.default_litellm_params = default_litellm_params self.default_litellm_params.setdefault("timeout", timeout) self.default_litellm_params.setdefault("max_retries", 0) - self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups}) + self.default_litellm_params.setdefault("metadata", {}).update( + {"caching_groups": caching_groups} + ) ### CACHING ### - cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache + cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache redis_cache = None - cache_config = {} - if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): + cache_config = {} + if redis_url is not None or ( + redis_host is not None + and redis_port is not None + and redis_password is not None + ): cache_type = "redis" if redis_url is not None: - cache_config['url'] = redis_url + cache_config["url"] = redis_url if redis_host is not None: - cache_config['host'] = redis_host + cache_config["host"] = redis_host if redis_port is not None: - cache_config['port'] = str(redis_port) # type: ignore + cache_config["port"] = str(redis_port) # type: ignore if redis_password is not None: - cache_config['password'] = redis_password + cache_config["password"] = redis_password # Add additional key-value pairs from cache_kwargs cache_config.update(cache_kwargs) @@ -141,43 +179,43 @@ class Router: if cache_responses: if litellm.cache is None: # the cache can be initialized on the proxy server. We should not overwrite it - litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore + litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore self.cache_responses = cache_responses - self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. - ### ROUTING SETUP ### + self.cache = DualCache( + redis_cache=redis_cache, in_memory_cache=InMemoryCache() + ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. 
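
The hunk above wires the Router's optional Redis-backed response cache: a `cache_config` is assembled from `redis_url` or from `redis_host`/`redis_port`/`redis_password`, and a `DualCache` tracks cooldowns and usage. As an illustrative usage sketch only (the Redis values and environment variables are placeholders, not taken from the repo), the constructor arguments involved are typically combined like this:

```
import os

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.environ["OPENAI_API_KEY"],
        },
    }
]

router = Router(
    model_list=model_list,
    # all three Redis values must be set for the Redis branch above to apply
    redis_host=os.environ.get("REDIS_HOST", "localhost"),
    redis_port=6379,
    redis_password=os.environ.get("REDIS_PASSWORD", "placeholder-password"),
    cache_responses=True,  # initializes litellm.Cache if one isn't set already
    routing_strategy="simple-shuffle",
)
```
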
+ ### ROUTING SETUP ### if routing_strategy == "least-busy": - self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache) + self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache) ## add callback - if isinstance(litellm.input_callback, list): - litellm.input_callback.append(self.leastbusy_logger) # type: ignore - else: - litellm.input_callback = [self.leastbusy_logger] # type: ignore + if isinstance(litellm.input_callback, list): + litellm.input_callback.append(self.leastbusy_logger) # type: ignore + else: + litellm.input_callback = [self.leastbusy_logger] # type: ignore if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.leastbusy_logger) # type: ignore - ## USAGE TRACKING ## + litellm.callbacks.append(self.leastbusy_logger) # type: ignore + ## USAGE TRACKING ## if isinstance(litellm.success_callback, list): litellm.success_callback.append(self.deployment_callback) else: litellm.success_callback = [self.deployment_callback] - + if isinstance(litellm.failure_callback, list): litellm.failure_callback.append(self.deployment_callback_on_failure) else: litellm.failure_callback = [self.deployment_callback_on_failure] - self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n") + self.print_verbose( + f"Intialized router with Routing strategy: {self.routing_strategy}\n" + ) - ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS - def completion(self, - model: str, - messages: List[Dict[str, str]], - **kwargs): + def completion(self, model: str, messages: List[Dict[str, str]], **kwargs): """ Example usage: response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}] """ - try: + try: kwargs["model"] = model kwargs["messages"] = messages kwargs["original_function"] = self._completion @@ -187,39 +225,47 @@ class Router: with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: # Submit the function to the executor with a timeout future = executor.submit(self.function_with_fallbacks, **kwargs) - response = future.result(timeout=timeout) # type: ignore + response = future.result(timeout=timeout) # type: ignore return response - except Exception as e: + except Exception as e: raise e - def _completion( - self, - model: str, - messages: List[Dict[str, str]], - **kwargs): - - try: + def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs): + try: # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + deployment = self.get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + {"deployment": deployment["litellm_params"]["model"]} + ) data = deployment["litellm_params"].copy() kwargs["model_info"] = deployment.get("model_info", {}) - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) model_client = self._get_client(deployment=deployment, kwargs=kwargs) - return litellm.completion(**{**data, "messages": messages, "caching": 
self.cache_responses, "client": model_client, **kwargs}) - except Exception as e: + return litellm.completion( + **{ + **data, + "messages": messages, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + except Exception as e: raise e - - async def acompletion(self, - model: str, - messages: List[Dict[str, str]], - **kwargs): - try: + + async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): + try: kwargs["model"] = model kwargs["messages"] = messages kwargs["original_function"] = self._acompletion @@ -230,42 +276,55 @@ class Router: response = await self.async_function_with_fallbacks(**kwargs) return response - except Exception as e: + except Exception as e: raise e - async def _acompletion( - self, - model: str, - messages: List[Dict[str, str]], - **kwargs): - try: - self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}") - deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): + try: + self.print_verbose( + f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + {"deployment": deployment["litellm_params"]["model"]} + ) kwargs["model_info"] = deployment.get("model_info", {}) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") - self.total_calls[model_name] +=1 - response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) - self.success_calls[model_name] +=1 + model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + self.total_calls[model_name] += 1 + response = await litellm.acompletion( + **{ + **data, + "messages": messages, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 return response - except Exception as e: + except Exception as e: if model_name is not None: - self.fail_calls[model_name] +=1 + self.fail_calls[model_name] += 1 raise e - - def image_generation(self, - prompt: str, - model: str, - **kwargs): - try: + + def image_generation(self, prompt: str, model: str, **kwargs): + try: kwargs["model"] = model kwargs["prompt"] = prompt kwargs["original_function"] = self._image_generation @@ -275,41 +334,55 @@ class Router: response = self.function_with_fallbacks(**kwargs) return response - except Exception as e: + except Exception as e: raise e - - def _image_generation(self, - prompt: str, - model: str, - **kwargs): - try: - self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") - deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], 
specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + + def _image_generation(self, prompt: str, model: str, **kwargs): + try: + self.print_verbose( + f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + {"deployment": deployment["litellm_params"]["model"]} + ) kwargs["model_info"] = deployment.get("model_info", {}) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") - self.total_calls[model_name] +=1 - response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) - self.success_calls[model_name] +=1 + model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + self.total_calls[model_name] += 1 + response = litellm.image_generation( + **{ + **data, + "prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 return response - except Exception as e: + except Exception as e: if model_name is not None: - self.fail_calls[model_name] +=1 + self.fail_calls[model_name] += 1 raise e - - async def aimage_generation(self, - prompt: str, - model: str, - **kwargs): - try: + + async def aimage_generation(self, prompt: str, model: str, **kwargs): + try: kwargs["model"] = model kwargs["prompt"] = prompt kwargs["original_function"] = self._aimage_generation @@ -319,84 +392,117 @@ class Router: response = await self.async_function_with_fallbacks(**kwargs) return response - except Exception as e: + except Exception as e: raise e - - async def _aimage_generation(self, - prompt: str, - model: str, - **kwargs): - try: - self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") - deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + + async def _aimage_generation(self, prompt: str, model: str, **kwargs): + try: + self.print_verbose( + f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "prompt"}], + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + {"deployment": deployment["litellm_params"]["model"]} + ) kwargs["model_info"] = deployment.get("model_info", {}) data = deployment["litellm_params"].copy() model_name = data["model"] - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # 
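The `image_generation()` helpers above follow the same pattern for image models. A hedged sketch; `"dall-e-3"` is an illustrative deployment name, not something configured by this patch:

```python
import os
from litellm import Router

image_router = Router(
    model_list=[
        {
            "model_name": "dall-e-3",  # placeholder model group
            "litellm_params": {
                "model": "dall-e-3",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)
image_response = image_router.image_generation(
    model="dall-e-3",
    prompt="a watercolor painting of a lighthouse at dusk",
)
```

Note that `_image_generation()` passes a dummy `messages=[{"role": "user", "content": "prompt"}]` into `get_available_deployment()` so the same selection logic can be reused for non-chat calls.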
prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") - self.total_calls[model_name] +=1 - response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) - self.success_calls[model_name] +=1 + model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) + self.total_calls[model_name] += 1 + response = await litellm.aimage_generation( + **{ + **data, + "prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 return response - except Exception as e: + except Exception as e: if model_name is not None: - self.fail_calls[model_name] +=1 + self.fail_calls[model_name] += 1 raise e - def text_completion(self, - model: str, - prompt: str, - is_retry: Optional[bool] = False, - is_fallback: Optional[bool] = False, - is_async: Optional[bool] = False, - **kwargs): - try: + def text_completion( + self, + model: str, + prompt: str, + is_retry: Optional[bool] = False, + is_fallback: Optional[bool] = False, + is_async: Optional[bool] = False, + **kwargs, + ): + try: kwargs.setdefault("metadata", {}).update({"model_group": model}) - messages=[{"role": "user", "content": prompt}] + messages = [{"role": "user", "content": prompt}] # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None)) + deployment = self.get_available_deployment( + model=model, + messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) data = deployment["litellm_params"].copy() - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) # call via litellm.completion() - return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore - except Exception as e: + return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore + except Exception as e: if self.num_retries > 0: kwargs["model"] = model kwargs["messages"] = messages kwargs["original_exception"] = e kwargs["original_function"] = self.completion return self.function_with_retries(**kwargs) - else: + else: raise e - - async def atext_completion(self, - model: str, - prompt: str, - is_retry: Optional[bool] = False, - is_fallback: Optional[bool] = False, - is_async: Optional[bool] = False, - **kwargs): - try: + + async def atext_completion( + self, + model: str, + prompt: str, + is_retry: Optional[bool] = False, + is_fallback: Optional[bool] = False, + is_async: Optional[bool] = False, + **kwargs, + ): + try: kwargs.setdefault("metadata", {}).update({"model_group": model}) - messages=[{"role": "user", "content": prompt}] + messages = [{"role": "user", "content": prompt}] # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None)) + deployment = self.get_available_deployment( + model=model, + 
messages=messages, + specific_deployment=kwargs.pop("specific_deployment", None), + ) data = deployment["litellm_params"].copy() - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params kwargs[k] = v elif k == "metadata": kwargs[k].update(v) @@ -411,101 +517,150 @@ class Router: else: data["model"] = original_model_string # call via litellm.atext_completion() - response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore + response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore return response - except Exception as e: + except Exception as e: if self.num_retries > 0: kwargs["model"] = model kwargs["messages"] = messages kwargs["original_exception"] = e kwargs["original_function"] = self.completion return self.function_with_retries(**kwargs) - else: + else: raise e - def embedding(self, - model: str, - input: Union[str, List], - is_async: Optional[bool] = False, - **kwargs) -> Union[List[float], None]: + def embedding( + self, + model: str, + input: Union[str, List], + is_async: Optional[bool] = False, + **kwargs, + ) -> Union[List[float], None]: # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) + deployment = self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) kwargs.setdefault("model_info", {}) - kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks + kwargs.setdefault("metadata", {}).update( + {"model_group": model, "deployment": deployment["litellm_params"]["model"]} + ) # [TODO]: move to using async_function_with_fallbacks data = deployment["litellm_params"].copy() - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) model_client = self._get_client(deployment=deployment, kwargs=kwargs) # call via litellm.embedding() - return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) + return litellm.embedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) - async def aembedding(self, - model: str, - input: Union[str, List], - is_async: Optional[bool] = True, - **kwargs) -> Union[List[float], None]: + async def aembedding( + self, + model: str, + input: Union[str, List], + is_async: Optional[bool] = True, + **kwargs, + ) -> Union[List[float], None]: # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) + deployment 
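`embedding()`/`aembedding()` above reuse the same deployment selection for embedding inputs. A minimal sketch, assuming the `router` from the earlier completion example also contains an embedding-capable deployment (the model name below is a placeholder):

```python
# `input` accepts either a single string or a list of strings.
vectors = router.embedding(
    model="text-embedding-ada-002",  # placeholder embedding model group
    input=["first chunk of text", "second chunk of text"],
)
```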
= self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + {"model_group": model, "deployment": deployment["litellm_params"]["model"]} + ) data = deployment["litellm_params"].copy() kwargs["model_info"] = deployment.get("model_info", {}) - for k, v in self.default_litellm_params.items(): - if k not in kwargs: # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") - - return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) + model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="async" + ) - async def async_function_with_fallbacks(self, *args, **kwargs): + return await litellm.aembedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + async def async_function_with_fallbacks(self, *args, **kwargs): """ Try calling the function_with_retries If it fails after num_retries, fall back to another model group """ model_group = kwargs.get("model") fallbacks = kwargs.get("fallbacks", self.fallbacks) - context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks) - try: + context_window_fallbacks = kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + try: response = await self.async_function_with_retries(*args, **kwargs) - self.print_verbose(f'Async Response: {response}') + self.print_verbose(f"Async Response: {response}") return response - except Exception as e: - self.print_verbose(f"An exception occurs: {e}\n\n Traceback{traceback.format_exc()}") + except Exception as e: + self.print_verbose( + f"An exception occurs: {e}\n\n Traceback{traceback.format_exc()}" + ) original_exception = e - try: + try: self.print_verbose(f"Trying to fallback b/w models") - if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None: + if ( + isinstance(e, litellm.ContextWindowExceededError) + and context_window_fallbacks is not None + ): fallback_model_group = None - for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + for ( + item + ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break - - if fallback_model_group is None: + + if fallback_model_group is None: raise original_exception - - for mg in fallback_model_group: + + for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ try: kwargs["model"] = mg - response = await self.async_function_with_retries(*args, **kwargs) - return response - except Exception as e: + response = await self.async_function_with_retries( + *args, **kwargs + ) + return response + except Exception as e: pass - elif fallbacks is not None: + elif fallbacks is not None: self.print_verbose(f"inside model fallbacks: {fallbacks}") for item in fallbacks: if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break - for mg in fallback_model_group: + for mg in 
fallback_model_group: """ Iterate through the model groups and try calling that deployment """ @@ -514,104 +669,157 @@ class Router: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg kwargs["metadata"]["model_group"] = mg - response = await self.async_function_with_retries(*args, **kwargs) - return response - except Exception as e: + response = await self.async_function_with_retries( + *args, **kwargs + ) + return response + except Exception as e: raise e - except Exception as e: + except Exception as e: self.print_verbose(f"An exception occurred - {str(e)}") traceback.print_exc() raise original_exception - + async def async_function_with_retries(self, *args, **kwargs): - self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}") + self.print_verbose( + f"Inside async function with retries: args - {args}; kwargs - {kwargs}" + ) original_function = kwargs.pop("original_function") fallbacks = kwargs.pop("fallbacks", self.fallbacks) - context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) - self.print_verbose(f"async function w/ retries: original_function - {original_function}") + context_window_fallbacks = kwargs.pop( + "context_window_fallbacks", self.context_window_fallbacks + ) + self.print_verbose( + f"async function w/ retries: original_function - {original_function}" + ) num_retries = kwargs.pop("num_retries") - try: + try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = await original_function(*args, **kwargs) return response - except Exception as e: + except Exception as e: original_exception = e ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available - if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None) - or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)): + if ( + isinstance(original_exception, litellm.ContextWindowExceededError) + and context_window_fallbacks is None + ) or ( + isinstance(original_exception, openai.RateLimitError) + and fallbacks is not None + ): raise original_exception ### RETRY #### check if it should retry + back-off if required - if "No models available" in str(e): - timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries) + if "No models available" in str(e): + timeout = litellm._calculate_retry_after( + remaining_retries=num_retries, max_retries=num_retries + ) await asyncio.sleep(timeout) - elif hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code): + elif ( + hasattr(original_exception, "status_code") + and hasattr(original_exception, "response") + and litellm._should_retry(status_code=original_exception.status_code) + ): if hasattr(original_exception.response, "headers"): - timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers) + timeout = litellm._calculate_retry_after( + remaining_retries=num_retries, + max_retries=num_retries, + response_headers=original_exception.response.headers, + ) else: - timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries) + timeout = litellm._calculate_retry_after( + remaining_retries=num_retries, max_retries=num_retries + ) await asyncio.sleep(timeout) - else: + else: raise 
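`async_function_with_fallbacks()` above consumes two fallback maps: `fallbacks` for general errors (e.g. rate limits) and `context_window_fallbacks` for `ContextWindowExceededError`, both in the `[{"model-group": ["fallback-group", ...]}]` shape shown in the inline comments. A configuration sketch, assuming the Router constructor accepts these options as the attributes referenced above suggest (group names are placeholders):

```python
router = Router(
    model_list=model_list,                                        # as in the earlier sketch
    num_retries=2,                                                # retries attempted before falling back
    fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],                     # general fallback map
    context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-4-32k"]}],  # used when the prompt is too long
)
```

Per the retry wrapper below, a rate-limit error with `fallbacks` configured is raised straight through to the fallback layer instead of being retried on the same deployment group.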
original_exception - + ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) - + for current_attempt in range(num_retries): - self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}") + self.print_verbose( + f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}" + ) try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = await original_function(*args, **kwargs) - if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines + if inspect.iscoroutinefunction( + response + ): # async errors are often returned as coroutines response = await response return response - - except Exception as e: + + except Exception as e: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - if "No models available" in str(e): - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1) + if "No models available" in str(e): + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + min_timeout=1, + ) await asyncio.sleep(timeout) - elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code): + elif ( + hasattr(e, "status_code") + and hasattr(e, "response") + and litellm._should_retry(status_code=e.status_code) + ): if hasattr(e.response, "headers"): - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers) + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + response_headers=e.response.headers, + ) else: - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries) + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + ) await asyncio.sleep(timeout) - else: + else: raise e raise original_exception - - def function_with_fallbacks(self, *args, **kwargs): + + def function_with_fallbacks(self, *args, **kwargs): """ Try calling the function_with_retries If it fails after num_retries, fall back to another model group """ model_group = kwargs.get("model") fallbacks = kwargs.get("fallbacks", self.fallbacks) - context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks) - try: + context_window_fallbacks = kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + try: response = self.function_with_retries(*args, **kwargs) return response except Exception as e: original_exception = e self.print_verbose(f"An exception occurs {original_exception}") - try: - self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}") - if isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None: + try: + self.print_verbose( + f"Trying to fallback b/w models. 
Initial model group: {model_group}" + ) + if ( + isinstance(e, litellm.ContextWindowExceededError) + and context_window_fallbacks is not None + ): fallback_model_group = None - for item in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] + for ( + item + ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break - - if fallback_model_group is None: + + if fallback_model_group is None: raise original_exception - - for mg in fallback_model_group: + + for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ @@ -620,10 +828,10 @@ class Router: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg response = self.function_with_fallbacks(*args, **kwargs) - return response - except Exception as e: + return response + except Exception as e: pass - elif fallbacks is not None: + elif fallbacks is not None: self.print_verbose(f"inside model fallbacks: {fallbacks}") fallback_model_group = None for item in fallbacks: @@ -631,10 +839,10 @@ class Router: fallback_model_group = item[model_group] break - if fallback_model_group is None: + if fallback_model_group is None: raise original_exception - - for mg in fallback_model_group: + + for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ @@ -643,107 +851,158 @@ class Router: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg response = self.function_with_fallbacks(*args, **kwargs) - return response - except Exception as e: + return response + except Exception as e: raise e - except Exception as e: + except Exception as e: raise e raise original_exception - - def function_with_retries(self, *args, **kwargs): + + def function_with_retries(self, *args, **kwargs): """ - Try calling the model 3 times. Shuffle between available deployments. + Try calling the model 3 times. Shuffle between available deployments. 
""" - self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}") + self.print_verbose( + f"Inside function with retries: args - {args}; kwargs - {kwargs}" + ) original_function = kwargs.pop("original_function") num_retries = kwargs.pop("num_retries") fallbacks = kwargs.pop("fallbacks", self.fallbacks) - context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) - try: + context_window_fallbacks = kwargs.pop( + "context_window_fallbacks", self.context_window_fallbacks + ) + try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = original_function(*args, **kwargs) return response - except Exception as e: + except Exception as e: original_exception = e self.print_verbose(f"num retries in function with retries: {num_retries}") ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR - if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None) - or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)): + if ( + isinstance(original_exception, litellm.ContextWindowExceededError) + and context_window_fallbacks is None + ) or ( + isinstance(original_exception, openai.RateLimitError) + and fallbacks is not None + ): raise original_exception ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) ### RETRY for current_attempt in range(num_retries): - self.print_verbose(f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}") + self.print_verbose( + f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}" + ) try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = original_function(*args, **kwargs) return response - except Exception as e: + except Exception as e: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - if "No models available" in str(e): - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=1) + if "No models available" in str(e): + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + min_timeout=1, + ) time.sleep(timeout) - elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code): + elif ( + hasattr(e, "status_code") + and hasattr(e, "response") + and litellm._should_retry(status_code=e.status_code) + ): if hasattr(e.response, "headers"): - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers) + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + response_headers=e.response.headers, + ) else: - timeout = litellm._calculate_retry_after(remaining_retries=remaining_retries, max_retries=num_retries) + timeout = litellm._calculate_retry_after( + remaining_retries=remaining_retries, + max_retries=num_retries, + ) time.sleep(timeout) - else: + else: raise e raise original_exception ### HELPER FUNCTIONS - + def deployment_callback( self, - kwargs, # kwargs to completion - completion_response, # response from completion - start_time, end_time # start/end time + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time ): """ 
Function LiteLLM submits a callback to after a successful completion. Purpose of this is to update TPM/RPM usage per model """ - deployment_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None) - model_name = kwargs.get('model', None) # i.e. gpt35turbo - custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure + deployment_id = ( + kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None) + ) + model_name = kwargs.get("model", None) # i.e. gpt35turbo + custom_llm_provider = kwargs.get("litellm_params", {}).get( + "custom_llm_provider", None + ) # i.e. azure if custom_llm_provider: model_name = f"{custom_llm_provider}/{model_name}" - if kwargs["stream"] is True: + if kwargs["stream"] is True: if kwargs.get("complete_streaming_response"): - total_tokens = kwargs.get("complete_streaming_response")['usage']['total_tokens'] + total_tokens = kwargs.get("complete_streaming_response")["usage"][ + "total_tokens" + ] self._set_deployment_usage(deployment_id, total_tokens) - else: - total_tokens = completion_response['usage']['total_tokens'] + else: + total_tokens = completion_response["usage"]["total_tokens"] self._set_deployment_usage(deployment_id, total_tokens) - - self.deployment_latency_map[model_name] = (end_time - start_time).total_seconds() + + self.deployment_latency_map[model_name] = ( + end_time - start_time + ).total_seconds() def deployment_callback_on_failure( - self, - kwargs, # kwargs to completion - completion_response, # response from completion - start_time, end_time # start/end time + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time ): - try: + try: exception = kwargs.get("exception", None) exception_type = type(exception) - exception_status = getattr(exception, 'status_code', "") - exception_cause = getattr(exception, '__cause__', "") - exception_message = getattr(exception, 'message', "") - exception_str = str(exception_type) + "Status: " + str(exception_status) + "Message: " + str(exception_cause) + str(exception_message) + "Full exception" + str(exception) - model_name = kwargs.get('model', None) # i.e. gpt35turbo - custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure - metadata = kwargs.get("litellm_params", {}).get('metadata', None) - deployment_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None) - self._set_cooldown_deployments(deployment_id) # setting deployment_id in cooldown deployments - if metadata: + exception_status = getattr(exception, "status_code", "") + exception_cause = getattr(exception, "__cause__", "") + exception_message = getattr(exception, "message", "") + exception_str = ( + str(exception_type) + + "Status: " + + str(exception_status) + + "Message: " + + str(exception_cause) + + str(exception_message) + + "Full exception" + + str(exception) + ) + model_name = kwargs.get("model", None) # i.e. gpt35turbo + custom_llm_provider = kwargs.get("litellm_params", {}).get( + "custom_llm_provider", None + ) # i.e. 
azure + metadata = kwargs.get("litellm_params", {}).get("metadata", None) + deployment_id = ( + kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None) + ) + self._set_cooldown_deployments( + deployment_id + ) # setting deployment_id in cooldown deployments + if metadata: deployment = metadata.get("deployment", None) deployment_exceptions = self.model_exception_map.get(deployment, []) deployment_exceptions.append(exception_str) @@ -751,70 +1010,81 @@ class Router: self.print_verbose("\nEXCEPTION FOR DEPLOYMENTS\n") self.print_verbose(self.model_exception_map) for model in self.model_exception_map: - self.print_verbose(f"Model {model} had {len(self.model_exception_map[model])} exception") + self.print_verbose( + f"Model {model} had {len(self.model_exception_map[model])} exception" + ) if custom_llm_provider: model_name = f"{custom_llm_provider}/{model_name}" - + except Exception as e: raise e - def log_retry(self, kwargs: dict, e: Exception) -> dict: + def log_retry(self, kwargs: dict, e: Exception) -> dict: """ When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing """ - try: + try: # Log failed model as the previous model - previous_model = {"exception_type": type(e).__name__, "exception_string": str(e)} - for k, v in kwargs.items(): # log everything in kwargs except the old previous_models value - prevent nesting + previous_model = { + "exception_type": type(e).__name__, + "exception_string": str(e), + } + for ( + k, + v, + ) in ( + kwargs.items() + ): # log everything in kwargs except the old previous_models value - prevent nesting if k != "metadata": previous_model[k] = v - elif k == "metadata" and isinstance(v, dict): - previous_model["metadata"] = {} # type: ignore - for metadata_k, metadata_v in kwargs['metadata'].items(): - if metadata_k != "previous_models": - previous_model[k][metadata_k] = metadata_v # type: ignore + elif k == "metadata" and isinstance(v, dict): + previous_model["metadata"] = {} # type: ignore + for metadata_k, metadata_v in kwargs["metadata"].items(): + if metadata_k != "previous_models": + previous_model[k][metadata_k] = metadata_v # type: ignore self.previous_models.append(previous_model) kwargs["metadata"]["previous_models"] = self.previous_models return kwargs - except Exception as e: + except Exception as e: raise e - def _set_cooldown_deployments(self, - deployment: Optional[str]=None): + def _set_cooldown_deployments(self, deployment: Optional[str] = None): """ Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute """ if deployment is None: return - + current_minute = datetime.now().strftime("%H-%M") # get current fails for deployment - # update the number of failed calls - # if it's > allowed fails - # cooldown deployment + # update the number of failed calls + # if it's > allowed fails + # cooldown deployment current_fails = self.failed_calls.get_cache(key=deployment) or 0 updated_fails = current_fails + 1 - self.print_verbose(f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}") - if updated_fails > self.allowed_fails: + self.print_verbose( + f"Attempting to add {deployment} to cooldown list. 
updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" + ) + if updated_fails > self.allowed_fails: # get the current cooldown list for that minute - cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls + cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cached_value = self.cache.get_cache(key=cooldown_key) self.print_verbose(f"adding {deployment} to cooldown models") # update value try: - if deployment in cached_value: + if deployment in cached_value: pass - else: + else: cached_value = cached_value + [deployment] # save updated value - self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1) + self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1) except: cached_value = [deployment] # save updated value - self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1) + self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1) else: - self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=1) + self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=1) def _get_cooldown_deployments(self): """ @@ -832,10 +1102,12 @@ class Router: self.print_verbose(f"retrieve cooldown models: {cooldown_models}") return cooldown_models - def get_usage_based_available_deployment(self, - model: str, - messages: Optional[List[Dict[str, str]]] = None, - input: Optional[Union[str, List]] = None): + def get_usage_based_available_deployment( + self, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): """ Returns a deployment with the lowest TPM/RPM usage. """ @@ -861,10 +1133,9 @@ class Router: # ---------------------- lowest_tpm = float("inf") deployment = None - + # load model context map models_context_map = litellm.model_cost - # return deployment with lowest tpm usage for item in potential_deployments: @@ -873,8 +1144,12 @@ class Router: if item_tpm == 0: return item - elif ("tpm" in item and item_tpm + token_count > item["tpm"] - or "rpm" in item and item_rpm + 1 >= item["rpm"]): # if user passed in tpm / rpm in the model_list + elif ( + "tpm" in item + and item_tpm + token_count > item["tpm"] + or "rpm" in item + and item_rpm + 1 >= item["rpm"] + ): # if user passed in tpm / rpm in the model_list continue elif item_tpm < lowest_tpm: lowest_tpm = item_tpm @@ -887,16 +1162,13 @@ class Router: # return model return deployment - def _get_deployment_usage( - self, - deployment_name: str - ): + def _get_deployment_usage(self, deployment_name: str): # ------------ # Setup values # ------------ current_minute = datetime.now().strftime("%H-%M") - tpm_key = f'{deployment_name}:tpm:{current_minute}' - rpm_key = f'{deployment_name}:rpm:{current_minute}' + tpm_key = f"{deployment_name}:tpm:{current_minute}" + rpm_key = f"{deployment_name}:rpm:{current_minute}" # ------------ # Return usage @@ -915,33 +1187,33 @@ class Router: except: cached_value = increment_value # save updated value - self.cache.set_cache(value=cached_value, key=key, ttl=self.default_cache_time_seconds) + self.cache.set_cache( + value=cached_value, key=key, ttl=self.default_cache_time_seconds + ) - def _set_deployment_usage( - self, - model_name: str, - total_tokens: int - ): + def _set_deployment_usage(self, model_name: str, total_tokens: int): # ------------ # Setup values # ------------ current_minute = datetime.now().strftime("%H-%M") - tpm_key = f'{model_name}:tpm:{current_minute}' 
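The cooldown and usage helpers above key everything by wall-clock minute so entries expire naturally and Redis calls stay grouped. A sketch of the key layout, taken directly from the diffed code (the deployment name is a placeholder):

```python
from datetime import datetime

deployment_name = "azure/chatgpt-v-2"                  # placeholder
current_minute = datetime.now().strftime("%H-%M")

tpm_key = f"{deployment_name}:tpm:{current_minute}"    # tokens used this minute
rpm_key = f"{deployment_name}:rpm:{current_minute}"    # requests made this minute
cooldown_key = f"{current_minute}:cooldown_models"     # deployment ids cooling down this minute
```

A deployment is only added to the cooldown list once its failures in the current window exceed `allowed_fails`; below that threshold the failure count itself is cached with a short TTL instead.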
- rpm_key = f'{model_name}:rpm:{current_minute}' + tpm_key = f"{model_name}:tpm:{current_minute}" + rpm_key = f"{model_name}:rpm:{current_minute}" # ------------ # Update usage # ------------ self.increment(tpm_key, total_tokens) self.increment(rpm_key, 1) - + def _start_health_check_thread(self): """ Starts a separate thread to perform health checks periodically. """ - health_check_thread = threading.Thread(target=self._perform_health_checks, daemon=True) + health_check_thread = threading.Thread( + target=self._perform_health_checks, daemon=True + ) health_check_thread.start() - + def _perform_health_checks(self): """ Periodically performs health checks on the servers. @@ -951,27 +1223,31 @@ class Router: self.healthy_deployments = self._health_check() # Adjust the time interval based on your needs time.sleep(15) - + def _health_check(self): """ Performs a health check on the deployments Returns the list of healthy deployments """ healthy_deployments = [] - for deployment in self.model_list: + for deployment in self.model_list: litellm_args = deployment["litellm_params"] - try: + try: start_time = time.time() - litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond - end_time = time.time() + litellm.completion( + messages=[{"role": "user", "content": ""}], + max_tokens=1, + **litellm_args, + ) # hit the server with a blank message to see how long it takes to respond + end_time = time.time() response_time = end_time - start_time logging.debug(f"response_time: {response_time}") healthy_deployments.append((deployment, response_time)) healthy_deployments.sort(key=lambda x: x[1]) - except Exception as e: + except Exception as e: pass return healthy_deployments - + def weighted_shuffle_by_latency(self, items): # Sort the items by latency sorted_items = sorted(items, key=lambda x: x[1]) @@ -980,18 +1256,19 @@ class Router: # Calculate the sum of all latencies total_latency = sum(latencies) # Calculate the weight for each latency (lower latency = higher weight) - weights = [total_latency-latency for latency in latencies] + weights = [total_latency - latency for latency in latencies] # Get a weighted random item - if sum(weights) == 0: + if sum(weights) == 0: chosen_item = random.choice(sorted_items)[0] - else: + else: chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0] return chosen_item def set_model_list(self, model_list: list): self.model_list = copy.deepcopy(model_list) - # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works + # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works import os + for model in self.model_list: litellm_params = model.get("litellm_params", {}) model_name = litellm_params.get("model") @@ -1001,11 +1278,15 @@ class Router: model["model_info"] = model_info #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## custom_llm_provider = litellm_params.get("custom_llm_provider") - custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or "" + custom_llm_provider = ( + custom_llm_provider or model_name.split("/", 1)[0] or "" + ) default_api_base = None default_api_key = None if custom_llm_provider in litellm.openai_compatible_providers: - _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name) + _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( + 
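`weighted_shuffle_by_latency()` above biases selection toward faster deployments by weighting each item with `total_latency - latency`. A standalone illustration of that weighting (deployment names and latencies are made up):

```python
import random

# (item, latency-in-seconds) pairs, e.g. as produced by the health check above
items_with_latencies = [("deployment-A", 0.8), ("deployment-B", 1.4), ("deployment-C", 2.1)]

latencies = [latency for _, latency in items_with_latencies]
total_latency = sum(latencies)
weights = [total_latency - latency for latency in latencies]  # lower latency -> higher weight

chosen = random.choices(items_with_latencies, weights=weights, k=1)[0][0]
print(chosen)  # most often "deployment-A"
```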
model=model_name + ) default_api_base = api_base default_api_key = api_key if ( @@ -1019,7 +1300,7 @@ class Router: ): # glorified / complicated reading of configs # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env - # we do this here because we init clients for Azure, OpenAI and we need to set the right key + # we do this here because we init clients for Azure, OpenAI and we need to set the right key api_key = litellm_params.get("api_key") or default_api_key if api_key and api_key.startswith("os.environ/"): api_key_env_name = api_key.replace("os.environ/", "") @@ -1028,7 +1309,9 @@ class Router: api_base = litellm_params.get("api_base") base_url = litellm_params.get("base_url") - api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure + api_base = ( + api_base or base_url or default_api_base + ) # allow users to pass in `api_base` or `base_url` for azure if api_base and api_base.startswith("os.environ/"): api_base_env_name = api_base.replace("os.environ/", "") api_base = litellm.get_secret(api_base_env_name) @@ -1046,25 +1329,33 @@ class Router: timeout = litellm.get_secret(timeout_env_name) litellm_params["timeout"] = timeout - stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout - if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): + stream_timeout = litellm_params.pop( + "stream_timeout", timeout + ) # if no stream_timeout is set, default to timeout + if isinstance(stream_timeout, str) and stream_timeout.startswith( + "os.environ/" + ): stream_timeout_env_name = stream_timeout.replace("os.environ/", "") stream_timeout = litellm.get_secret(stream_timeout_env_name) litellm_params["stream_timeout"] = stream_timeout max_retries = litellm_params.pop("max_retries", 2) - if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): + if isinstance(max_retries, str) and max_retries.startswith( + "os.environ/" + ): max_retries_env_name = max_retries.replace("os.environ/", "") max_retries = litellm.get_secret(max_retries_env_name) litellm_params["max_retries"] = max_retries - + if "azure" in model_name: if api_base is None: - raise ValueError(f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}") + raise ValueError( + f"api_base is required for Azure OpenAI. Set it on your config. 
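`set_model_list()` above lets `api_key`, `api_base`, `timeout`, `stream_timeout`, and `max_retries` be written as `os.environ/VAR_NAME` references, resolved via `litellm.get_secret()` when the clients are initialized. A sketch of such an entry (deployment and variable names are placeholders):

```python
model_entry = {
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",               # placeholder Azure deployment
        "api_key": "os.environ/AZURE_API_KEY",      # resolved through litellm.get_secret()
        "api_base": "os.environ/AZURE_API_BASE",
        "timeout": 30,
        "stream_timeout": 60,                       # optional; defaults to `timeout`
        "max_retries": 2,                           # default when omitted
    },
}
```

Streaming clients are created separately with `stream_timeout`, so streamed calls can run on a longer budget than regular ones.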
Model - {model}" + ) if api_version is None: api_version = "2023-07-01-preview" - if "gateway.ai.cloudflare.com" in api_base: - if not api_base.endswith("/"): + if "gateway.ai.cloudflare.com" in api_base: + if not api_base.endswith("/"): api_base += "/" azure_model = model_name.replace("azure/", "") api_base += f"{azure_model}" @@ -1073,14 +1364,14 @@ class Router: base_url=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, ) model["client"] = openai.AzureOpenAI( api_key=api_key, base_url=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, ) # streaming clients can have diff timeouts @@ -1089,24 +1380,28 @@ class Router: base_url=api_base, api_version=api_version, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) model["stream_client"] = openai.AzureOpenAI( api_key=api_key, base_url=api_base, api_version=api_version, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) else: - self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}") + self.print_verbose( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}" + ) model["async_client"] = openai.AsyncAzureOpenAI( api_key=api_key, azure_endpoint=api_base, api_version=api_version, timeout=timeout, max_retries=max_retries, - http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport(), + ), # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, @@ -1114,7 +1409,9 @@ class Router: api_version=api_version, timeout=timeout, max_retries=max_retries, - http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore + http_client=httpx.Client( + transport=CustomHTTPTransport(), + ), # type: ignore ) # streaming clients should have diff timeouts model["stream_async_client"] = openai.AsyncAzureOpenAI( @@ -1130,22 +1427,24 @@ class Router: azure_endpoint=api_base, api_version=api_version, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) - + else: - self.print_verbose(f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}") + self.print_verbose( + f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}" + ) model["async_client"] = openai.AsyncOpenAI( api_key=api_key, base_url=api_base, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, ) model["client"] = openai.OpenAI( api_key=api_key, base_url=api_base, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, ) # streaming clients should have diff timeouts @@ -1153,7 +1452,7 @@ class Router: api_key=api_key, base_url=api_base, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) # streaming clients should have diff timeouts @@ -1161,7 +1460,7 @@ class Router: api_key=api_key, base_url=api_base, timeout=stream_timeout, - max_retries=max_retries + max_retries=max_retries, ) ############ End of initializing Clients for OpenAI/Azure ################### @@ -1171,9 +1470,15 @@ class Router: ############ Users can either pass tpm/rpm as a litellm_param or a router param ########### # for get_available_deployment, we use the litellm_param["rpm"] # in this snippet we also set rpm to be a litellm_param - if model["litellm_params"].get("rpm") is None and model.get("rpm") is 
not None: + if ( + model["litellm_params"].get("rpm") is None + and model.get("rpm") is not None + ): model["litellm_params"]["rpm"] = model.get("rpm") - if model["litellm_params"].get("tpm") is None and model.get("tpm") is not None: + if ( + model["litellm_params"].get("tpm") is None + and model.get("tpm") is not None + ): model["litellm_params"]["tpm"] = model.get("tpm") self.model_names = [m["model_name"] for m in model_list] @@ -1204,19 +1509,20 @@ class Router: else: return deployment.get("client", None) - def print_verbose(self, print_statement): + def print_verbose(self, print_statement): try: - if self.set_verbose or litellm.set_verbose: - print(f"LiteLLM.Router: {print_statement}") # noqa + if self.set_verbose or litellm.set_verbose: + print(f"LiteLLM.Router: {print_statement}") # noqa except: pass - def get_available_deployment(self, - model: str, - messages: Optional[List[Dict[str, str]]] = None, - input: Optional[Union[str, List]] = None, - specific_deployment: Optional[bool] = False - ): + def get_available_deployment( + self, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + specific_deployment: Optional[bool] = False, + ): """ Returns the deployment based on routing strategy """ @@ -1225,69 +1531,79 @@ class Router: # When this was no explicit we had several issues with fallbacks timing out if specific_deployment == True: # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment - for deployment in self.model_list: + for deployment in self.model_list: deployment_model = deployment.get("litellm_params").get("model") - if deployment_model == model: + if deployment_model == model: # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 # return the first deployment where the `model` matches the specificed deployment name return deployment - raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}") + raise ValueError( + f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}" + ) # check if aliases set on litellm model alias map if model in self.model_group_alias: - self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}") + self.print_verbose( + f"Using a model alias. 
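`get_available_deployment()` above resolves either a model-group name or, with `specific_deployment=True`, an exact deployment name from `litellm_params["model"]`. A sketch of both call styles, assuming the `router` from the earlier examples:

```python
# Pick any healthy deployment in the "gpt-3.5-turbo" group (strategy-dependent).
deployment = router.get_available_deployment(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)

# Target one specific deployment by its litellm_params["model"] value.
pinned = router.get_available_deployment(
    model="azure/chatgpt-v-2",
    specific_deployment=True,
)
```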
Got Request for {model}, sending requests to {self.model_group_alias.get(model)}" + ) model = self.model_group_alias[model] ## get healthy deployments - ### get all deployments + ### get all deployments healthy_deployments = [m for m in self.model_list if m["model_name"] == model] - if len(healthy_deployments) == 0: - # check if the user sent in a deployment name instead - healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model] + if len(healthy_deployments) == 0: + # check if the user sent in a deployment name instead + healthy_deployments = [ + m for m in self.model_list if m["litellm_params"]["model"] == model + ] self.print_verbose(f"initial list of deployments: {healthy_deployments}") - # filter out the deployments currently cooling down - deployments_to_remove = [] + # filter out the deployments currently cooling down + deployments_to_remove = [] # cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"] - cooldown_deployments = self._get_cooldown_deployments() + cooldown_deployments = self._get_cooldown_deployments() self.print_verbose(f"cooldown deployments: {cooldown_deployments}") # Find deployments in model_list whose model_id is cooling down - for deployment in healthy_deployments: + for deployment in healthy_deployments: deployment_id = deployment["model_info"]["id"] - if deployment_id in cooldown_deployments: + if deployment_id in cooldown_deployments: deployments_to_remove.append(deployment) # remove unhealthy deployments from healthy deployments for deployment in deployments_to_remove: healthy_deployments.remove(deployment) - self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}") - if len(healthy_deployments) == 0: + self.print_verbose( + f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}" + ) + if len(healthy_deployments) == 0: raise ValueError("No models available") if litellm.model_alias_map and model in litellm.model_alias_map: model = litellm.model_alias_map[ model ] # update the model to the actual value if an alias has been passed in if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: - deployments = self.leastbusy_logger.get_available_deployments(model_group=model) + deployments = self.leastbusy_logger.get_available_deployments( + model_group=model + ) # pick least busy deployment - min_traffic = float('inf') + min_traffic = float("inf") min_deployment = None - for k, v in deployments.items(): + for k, v in deployments.items(): if v < min_traffic: min_traffic = v min_deployment = k ############## No Available Deployments passed, we do a random pick ################# - if min_deployment is None: + if min_deployment is None: min_deployment = random.choice(healthy_deployments) ############## Available Deployments passed, we find the relevant item ################# - else: - for m in healthy_deployments: + else: + for m in healthy_deployments: if m["model_info"]["id"] == min_deployment: return m min_deployment = random.choice(healthy_deployments) - return min_deployment - elif self.routing_strategy == "simple-shuffle": + return min_deployment + elif self.routing_strategy == "simple-shuffle": # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm ############## Check if we can do a RPM/TPM based weighted pick ################# rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) @@ -1321,30 +1637,33 @@ 
class Router: ############## No RPM/TPM passed, we do a random pick ################# item = random.choice(healthy_deployments) return item or item[0] - elif self.routing_strategy == "latency-based-routing": + elif self.routing_strategy == "latency-based-routing": returned_item = None - lowest_latency = float('inf') + lowest_latency = float("inf") ### shuffles with priority for lowest latency # items_with_latencies = [('A', 10), ('B', 20), ('C', 30), ('D', 40)] - items_with_latencies = [] + items_with_latencies = [] for item in healthy_deployments: - items_with_latencies.append((item, self.deployment_latency_map[item["litellm_params"]["model"]])) + items_with_latencies.append( + (item, self.deployment_latency_map[item["litellm_params"]["model"]]) + ) returned_item = self.weighted_shuffle_by_latency(items_with_latencies) return returned_item - elif self.routing_strategy == "usage-based-routing": - return self.get_usage_based_available_deployment(model=model, messages=messages, input=input) - + elif self.routing_strategy == "usage-based-routing": + return self.get_usage_based_available_deployment( + model=model, messages=messages, input=input + ) + raise ValueError("No models available.") def flush_cache(self): litellm.cache = None self.cache.flush_cache() - - def reset(self): + + def reset(self): ## clean up on close - litellm.success_callback = [] - litellm.__async_success_callback = [] - litellm.failure_callback = [] - litellm._async_failure_callback = [] - self.flush_cache() - + litellm.success_callback = [] + litellm.__async_success_callback = [] + litellm.failure_callback = [] + litellm._async_failure_callback = [] + self.flush_cache() diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index 0080e3fa8..d06c5b309 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -1,96 +1,121 @@ #### What this does #### # identifies least busy deployment -# How is this achieved? +# How is this achieved? # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"} # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic} -# - use litellm.success + failure callbacks to log when a request completed +# - use litellm.success + failure callbacks to log when a request completed # - in get_available_deployment, for a given model group name -> pick based on traffic import dotenv, os, requests from typing import Optional + dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger -class LeastBusyLoggingHandler(CustomLogger): +class LeastBusyLoggingHandler(CustomLogger): def __init__(self, router_cache: DualCache): self.router_cache = router_cache - self.mapping_deployment_to_id: dict = {} - + self.mapping_deployment_to_id: dict = {} def log_pre_api_call(self, model, messages, kwargs): """ - Log when a model is being used. + Log when a model is being used. - Caching based on model group. + Caching based on model group. 
""" - try: - - if kwargs['litellm_params'].get('metadata') is None: + try: + if kwargs["litellm_params"].get("metadata") is None: pass - else: - deployment = kwargs['litellm_params']['metadata'].get('deployment', None) - model_group = kwargs['litellm_params']['metadata'].get('model_group', None) - id = kwargs['litellm_params'].get('model_info', {}).get('id', None) + else: + deployment = kwargs["litellm_params"]["metadata"].get( + "deployment", None + ) + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) if deployment is None or model_group is None or id is None: return - + # map deployment to id self.mapping_deployment_to_id[deployment] = id - + request_count_api_key = f"{model_group}_request_count" # update cache - request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} - request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1 - self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) + request_count_dict[deployment] = ( + request_count_dict.get(deployment, 0) + 1 + ) + self.router_cache.set_cache( + key=request_count_api_key, value=request_count_dict + ) except Exception as e: pass async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: - if kwargs['litellm_params'].get('metadata') is None: + if kwargs["litellm_params"].get("metadata") is None: pass - else: - deployment = kwargs['litellm_params']['metadata'].get('deployment', None) - model_group = kwargs['litellm_params']['metadata'].get('model_group', None) + else: + deployment = kwargs["litellm_params"]["metadata"].get( + "deployment", None + ) + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) if deployment is None or model_group is None: return - - + request_count_api_key = f"{model_group}_request_count" # decrement count in cache - request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) request_count_dict[deployment] = request_count_dict.get(deployment) - self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + self.router_cache.set_cache( + key=request_count_api_key, value=request_count_dict + ) except Exception as e: pass async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: - if kwargs['litellm_params'].get('metadata') is None: + if kwargs["litellm_params"].get("metadata") is None: pass - else: - deployment = kwargs['litellm_params']['metadata'].get('deployment', None) - model_group = kwargs['litellm_params']['metadata'].get('model_group', None) + else: + deployment = kwargs["litellm_params"]["metadata"].get( + "deployment", None + ) + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) if deployment is None or model_group is None: return - - + request_count_api_key = f"{model_group}_request_count" # decrement count in cache - request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) request_count_dict[deployment] = request_count_dict.get(deployment) - self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) + self.router_cache.set_cache( + 
key=request_count_api_key, value=request_count_dict + ) except Exception as e: pass def get_available_deployments(self, model_group: str): request_count_api_key = f"{model_group}_request_count" - request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + request_count_dict = ( + self.router_cache.get_cache(key=request_count_api_key) or {} + ) # map deployment to id return_dict = {} for key, value in request_count_dict.items(): return_dict[self.mapping_deployment_to_id[key]] = value - return return_dict \ No newline at end of file + return return_dict diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 411da8023..6b0df0f9a 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -2,6 +2,7 @@ import pytest, sys, os import importlib + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -11,24 +12,30 @@ import litellm @pytest.fixture(scope="function", autouse=True) def setup_and_teardown(): """ - This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. + This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. """ curr_dir = os.getcwd() # Get the current working directory - sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path import litellm + importlib.reload(litellm) print(litellm) # from litellm import Router, completion, aembedding, acompletion, embedding yield + def pytest_collection_modifyitems(config, items): # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests - custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name] - other_tests = [item for item in items if 'custom_logger' not in item.parent.name] + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] # Sort tests based on their names custom_logger_tests.sort(key=lambda x: x.name) other_tests.sort(key=lambda x: x.name) # Reorder the items list - items[:] = custom_logger_tests + other_tests \ No newline at end of file + items[:] = custom_logger_tests + other_tests diff --git a/litellm/tests/test_acooldowns_router.py b/litellm/tests/test_acooldowns_router.py index d1a33d10b..acd884c82 100644 --- a/litellm/tests/test_acooldowns_router.py +++ b/litellm/tests/test_acooldowns_router.py @@ -4,6 +4,7 @@ import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -11,53 +12,62 @@ import litellm from litellm import Router import concurrent from dotenv import load_dotenv + load_dotenv() -model_list = [{ # list of model deployments - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800, - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 1000000, - "rpm": 9000 - } +model_list = [ + { # list of model deployments + "model_name": 
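get_available_deployments above returns {deployment_id: requests_in_flight} from the router cache; the selection itself happens elsewhere in the router. A minimal sketch of how a caller could pick the least busy deployment from that dict (illustrative only, the router's own selection logic may differ):

def pick_least_busy(request_counts: dict) -> str:
    # request_counts: {deployment_id: in_flight_count}, as returned by
    # LeastBusyLoggingHandler.get_available_deployments(model_group).
    if not request_counts:
        raise ValueError("no deployments have been tracked yet")
    return min(request_counts, key=request_counts.get)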
"gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, ] -kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],} +kwargs = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hey, how's it going?"}], +} -def test_multiple_deployments_sync(): - import concurrent, time - litellm.set_verbose=False - results = [] - router = Router(model_list=model_list, - redis_host=os.getenv("REDIS_HOST"), - redis_password=os.getenv("REDIS_PASSWORD"), - redis_port=int(os.getenv("REDIS_PORT")), # type: ignore - routing_strategy="simple-shuffle", - set_verbose=True, - num_retries=1) # type: ignore - try: - for _ in range(3): - response = router.completion(**kwargs) - results.append(response) - print(results) - router.reset() - except Exception as e: - print(f"FAILED TEST!") - pytest.fail(f"An error occurred - {traceback.format_exc()}") +def test_multiple_deployments_sync(): + import concurrent, time + + litellm.set_verbose = False + results = [] + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=int(os.getenv("REDIS_PORT")), # type: ignore + routing_strategy="simple-shuffle", + set_verbose=True, + num_retries=1, + ) # type: ignore + try: + for _ in range(3): + response = router.completion(**kwargs) + results.append(response) + print(results) + router.reset() + except Exception as e: + print(f"FAILED TEST!") + pytest.fail(f"An error occurred - {traceback.format_exc()}") + # test_multiple_deployments_sync() @@ -67,13 +77,15 @@ def test_multiple_deployments_parallel(): results = [] futures = {} start_time = time.time() - router = Router(model_list=model_list, - redis_host=os.getenv("REDIS_HOST"), - redis_password=os.getenv("REDIS_PASSWORD"), - redis_port=int(os.getenv("REDIS_PORT")), # type: ignore - routing_strategy="simple-shuffle", - set_verbose=True, - num_retries=1) # type: ignore + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=int(os.getenv("REDIS_PORT")), # type: ignore + routing_strategy="simple-shuffle", + set_verbose=True, + num_retries=1, + ) # type: ignore # Assuming you have an executor instance defined somewhere in your code with concurrent.futures.ThreadPoolExecutor() as executor: for _ in range(5): @@ -82,7 +94,11 @@ def test_multiple_deployments_parallel(): # Retrieve the results from the futures while futures: - done, not_done = concurrent.futures.wait(futures.values(), timeout=10, return_when=concurrent.futures.FIRST_COMPLETED) + done, not_done = concurrent.futures.wait( + futures.values(), + timeout=10, + return_when=concurrent.futures.FIRST_COMPLETED, + ) for future in done: try: result = future.result() @@ -98,12 +114,14 @@ def test_multiple_deployments_parallel(): print(results) print(f"ELAPSED TIME: {end_time - start_time}") + # Assuming litellm, router, and executor are defined somewhere in your code + # test_multiple_deployments_parallel() def 
test_cooldown_same_model_name(): # users could have the same model with different api_base - # example + # example # azure/chatgpt, api_base: 1234 # azure/chatgpt, api_base: 1235 # if 1234 fails, it should only cooldown 1234 and then try with 1235 @@ -118,7 +136,7 @@ def test_cooldown_same_model_name(): "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), "api_base": "BAD_API_BASE", - "tpm": 90 + "tpm": 90, }, }, { @@ -128,7 +146,7 @@ def test_cooldown_same_model_name(): "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), "api_base": os.getenv("AZURE_API_BASE"), - "tpm": 0.000001 + "tpm": 0.000001, }, }, ] @@ -140,17 +158,12 @@ def test_cooldown_same_model_name(): redis_port=int(os.getenv("REDIS_PORT")), routing_strategy="simple-shuffle", set_verbose=True, - num_retries=3 + num_retries=3, ) # type: ignore response = router.completion( model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ] + messages=[{"role": "user", "content": "hello this request will pass"}], ) print(router.model_list) model_ids = [] @@ -159,10 +172,13 @@ def test_cooldown_same_model_name(): print("\n litellm model ids ", model_ids) # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960'] - assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names + assert ( + model_ids[0] != model_ids[1] + ) # ensure both models have a uuid added, and they have different names print("\ngot response\n", response) except Exception as e: pytest.fail(f"Got unexpected exception on router! - {e}") + test_cooldown_same_model_name() diff --git a/litellm/tests/test_add_function_to_prompt.py b/litellm/tests/test_add_function_to_prompt.py index a5ec53062..932e6edd1 100644 --- a/litellm/tests/test_add_function_to_prompt.py +++ b/litellm/tests/test_add_function_to_prompt.py @@ -9,68 +9,70 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm -## case 1: set_function_to_prompt not set + +## case 1: set_function_to_prompt not set def test_function_call_non_openai_model(): - try: + try: model = "claude-instant-1" - messages=[{"role": "user", "content": "what's the weather in sf?"}] + messages = [{"role": "user", "content": "what's the weather in sf?"}] functions = [ { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
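The cooldown test above depends on each deployment carrying a unique id, so a failure on one api_base cools down only that deployment and not every entry sharing the model name. A minimal sketch of that kind of per-deployment bookkeeping, with the class name and cooldown window chosen purely for illustration:

import time

class CooldownTracker:
    def __init__(self, cooldown_seconds: float = 60.0):
        self.cooldown_seconds = cooldown_seconds
        self._cooled_at = {}  # deployment_id -> timestamp of last failure

    def cooldown(self, deployment_id: str) -> None:
        self._cooled_at[deployment_id] = time.time()

    def is_available(self, deployment_id: str) -> bool:
        started = self._cooled_at.get(deployment_id)
        return started is None or (time.time() - started) > self.cooldown_seconds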
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } } ] - response = litellm.completion(model=model, messages=messages, functions=functions) - pytest.fail(f'An error occurred') - except Exception as e: + response = litellm.completion( + model=model, messages=messages, functions=functions + ) + pytest.fail(f"An error occurred") + except Exception as e: print(e) pass + test_function_call_non_openai_model() -## case 2: add_function_to_prompt set + +## case 2: add_function_to_prompt set def test_function_call_non_openai_model_litellm_mod_set(): litellm.add_function_to_prompt = True - try: + try: model = "claude-instant-1" - messages=[{"role": "user", "content": "what's the weather in sf?"}] + messages = [{"role": "user", "content": "what's the weather in sf?"}] functions = [ { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } } ] - response = litellm.completion(model=model, messages=messages, functions=functions) - print(f'response: {response}') - except Exception as e: - pytest.fail(f'An error occurred {e}') + response = litellm.completion( + model=model, messages=messages, functions=functions + ) + print(f"response: {response}") + except Exception as e: + pytest.fail(f"An error occurred {e}") -# test_function_call_non_openai_model_litellm_mod_set() \ No newline at end of file + +# test_function_call_non_openai_model_litellm_mod_set() diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 910bd8f6b..eb620273f 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1,4 +1,3 @@ - import sys, os import traceback from dotenv import load_dotenv @@ -8,7 +7,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout, acompletion @@ -20,18 +19,18 @@ import tempfile litellm.num_retries = 3 litellm.cache = None user_message = "Write a short poem about the sky" -messages = [{"content": user_message, "role": "user"}] +messages = [{"content": user_message, "role": "user"}] def load_vertex_ai_credentials(): # Define the path to the vertex_key.json file print("loading vertex ai credentials") filepath = os.path.dirname(os.path.abspath(__file__)) - vertex_key_path = filepath + '/vertex_key.json' + vertex_key_path = filepath + "/vertex_key.json" # Read the existing content of the file or create an empty dictionary try: - with open(vertex_key_path, 'r') as file: + with open(vertex_key_path, "r") as file: # Read the file content print("Read vertexai 
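Case 2 above enables litellm.add_function_to_prompt so a function spec can still be passed to a provider without native function calling. The actual injection logic lives elsewhere in the library; purely for illustration, one way a function schema could be folded into the message list (the helper name and wording are assumptions):

import json

def fold_functions_into_prompt(messages, functions):
    # Prepend a system message describing the available functions as JSON,
    # for providers that have no native function-calling parameter.
    spec = "\n".join(json.dumps(f, indent=2) for f in functions)
    system_msg = {"role": "system", "content": f"You may call these functions:\n{spec}"}
    return [system_msg] + messages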
file path") content = file.read() @@ -55,13 +54,13 @@ def load_vertex_ai_credentials(): service_account_key_data["private_key"] = private_key # Create a temporary file - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: # Write the updated content to the temporary file json.dump(service_account_key_data, temp_file, indent=2) - # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) + @pytest.mark.asyncio async def get_response(): @@ -89,43 +88,80 @@ def test_vertex_ai(): import random load_vertex_ai_credentials() - test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models - litellm.set_verbose=False + test_models = ( + litellm.vertex_chat_models + + litellm.vertex_code_chat_models + + litellm.vertex_text_models + + litellm.vertex_code_text_models + ) + litellm.set_verbose = False litellm.vertex_project = "reliablekeys" test_models = random.sample(test_models, 1) - test_models += litellm.vertex_language_models # always test gemini-pro + test_models += litellm.vertex_language_models # always test gemini-pro for model in test_models: try: - if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: + if model in [ + "code-gecko", + "code-gecko@001", + "code-gecko@002", + "code-gecko@latest", + "code-bison@001", + "text-bison@001", + ]: # our account does not have access to this model continue print("making request", model) - response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7) + response = completion( + model=model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + ) print("\nModel Response", response) print(response) assert type(response.choices[0].message.content) == str assert len(response.choices[0].message.content) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_vertex_ai() + def test_vertex_ai_stream(): load_vertex_ai_credentials() - litellm.set_verbose=False + litellm.set_verbose = False litellm.vertex_project = "reliablekeys" import random - test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models + test_models = ( + litellm.vertex_chat_models + + litellm.vertex_code_chat_models + + litellm.vertex_text_models + + litellm.vertex_code_text_models + ) test_models = random.sample(test_models, 1) - test_models += litellm.vertex_language_models # always test gemini-pro + test_models += litellm.vertex_language_models # always test gemini-pro for model in test_models: try: - if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: + if model in [ + "code-gecko", + "code-gecko@001", + "code-gecko@002", + "code-gecko@latest", + "code-bison@001", + "text-bison@001", + ]: # our account does not have access to this model continue print("making request", model) - response = completion(model=model, messages=[{"role": "user", "content": "write 10 line code code for saying hi"}], stream=True) + response = completion( + model=model, + messages=[ + {"role": "user", "content": "write 10 line code code for saying hi"} + ], + stream=True, + ) completed_str = "" for chunk in 
response: print(chunk) @@ -137,47 +173,86 @@ def test_vertex_ai_stream(): assert len(completed_str) > 4 except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_vertex_ai_stream() + + +# test_vertex_ai_stream() + @pytest.mark.asyncio async def test_async_vertexai_response(): import random + load_vertex_ai_credentials() - test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models + test_models = ( + litellm.vertex_chat_models + + litellm.vertex_code_chat_models + + litellm.vertex_text_models + + litellm.vertex_code_text_models + ) test_models = random.sample(test_models, 1) - test_models += litellm.vertex_language_models # always test gemini-pro + test_models += litellm.vertex_language_models # always test gemini-pro for model in test_models: - print(f'model being tested in async call: {model}') - if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: - # our account does not have access to this model - continue + print(f"model being tested in async call: {model}") + if model in [ + "code-gecko", + "code-gecko@001", + "code-gecko@002", + "code-gecko@latest", + "code-bison@001", + "text-bison@001", + ]: + # our account does not have access to this model + continue try: user_message = "Hello, how are you?" messages = [{"content": user_message, "role": "user"}] - response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5) + response = await acompletion( + model=model, messages=messages, temperature=0.7, timeout=5 + ) print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") + # asyncio.run(test_async_vertexai_response()) + @pytest.mark.asyncio async def test_async_vertexai_streaming_response(): import random + load_vertex_ai_credentials() - test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models + test_models = ( + litellm.vertex_chat_models + + litellm.vertex_code_chat_models + + litellm.vertex_text_models + + litellm.vertex_code_text_models + ) test_models = random.sample(test_models, 1) - test_models += litellm.vertex_language_models # always test gemini-pro + test_models += litellm.vertex_language_models # always test gemini-pro for model in test_models: - if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: - # our account does not have access to this model - continue + if model in [ + "code-gecko", + "code-gecko@001", + "code-gecko@002", + "code-gecko@latest", + "code-bison@001", + "text-bison@001", + ]: + # our account does not have access to this model + continue try: user_message = "Hello, how are you?" 
messages = [{"content": user_message, "role": "user"}] - response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True) + response = await acompletion( + model="gemini-pro", + messages=messages, + temperature=0.7, + timeout=5, + stream=True, + ) print(f"response: {response}") complete_response = "" async for chunk in response: @@ -185,44 +260,46 @@ async def test_async_vertexai_streaming_response(): complete_response += chunk.choices[0].delta.content print(f"complete_response: {complete_response}") assert len(complete_response) > 0 - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: print(e) pytest.fail(f"An exception occurred: {e}") + # asyncio.run(test_async_vertexai_streaming_response()) + def test_gemini_pro_vision(): try: load_vertex_ai_credentials() litellm.set_verbose = True - litellm.num_retries=0 + litellm.num_retries = 0 resp = litellm.completion( - model = "vertex_ai/gemini-pro-vision", + model="vertex_ai/gemini-pro-vision", messages=[ { "role": "user", "content": [ - { - "type": "text", - "text": "Whats in this image?" - }, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - } - } - ] + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" + }, + }, + ], } ], ) print(resp) except Exception as e: import traceback + traceback.print_exc() raise e + + # test_gemini_pro_vision() @@ -333,4 +410,4 @@ def test_gemini_pro_vision(): # import traceback # traceback.print_exc() # raise e -# test_gemini_pro_vision_async() \ No newline at end of file +# test_gemini_pro_vision_async() diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py index 7a722095e..485e86e7f 100644 --- a/litellm/tests/test_async_fn.py +++ b/litellm/tests/test_async_fn.py @@ -11,8 +11,10 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm from litellm import completion, acompletion, acreate + litellm.num_retries = 3 + def test_sync_response(): litellm.set_verbose = False user_message = "Hello, how are you?" @@ -20,35 +22,49 @@ def test_sync_response(): try: response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5) print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") + + # test_sync_response() + def test_sync_response_anyscale(): litellm.set_verbose = False user_message = "Hello, how are you?" messages = [{"content": user_message, "role": "user"}] try: - response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5) - except litellm.Timeout as e: + response = completion( + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", + messages=messages, + timeout=5, + ) + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") + + # test_sync_response_anyscale() + def test_async_response_openai(): import asyncio + litellm.set_verbose = True + async def test_get_response(): user_message = "Hello, how are you?" 
messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5) + response = await acompletion( + model="gpt-3.5-turbo", messages=messages, timeout=5 + ) print(f"response: {response}") print(f"response ms: {response._response_ms}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") @@ -56,54 +72,75 @@ def test_async_response_openai(): asyncio.run(test_get_response()) + # test_async_response_openai() + def test_async_response_azure(): import asyncio + litellm.set_verbose = True + async def test_get_response(): user_message = "What do you know?" messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY")) + response = await acompletion( + model="azure/gpt-turbo", + messages=messages, + base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), + api_key=os.getenv("AZURE_FRANCE_API_KEY"), + ) print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") asyncio.run(test_get_response()) + # test_async_response_azure() def test_async_anyscale_response(): import asyncio + litellm.set_verbose = True + async def test_get_response(): user_message = "Hello, how are you?" messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5) + response = await acompletion( + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", + messages=messages, + timeout=5, + ) # response = await response print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") asyncio.run(test_get_response()) + # test_async_anyscale_response() + def test_get_response_streaming(): import asyncio + async def test_async_call(): user_message = "write a short poem in one sentence" messages = [{"content": user_message, "role": "user"}] try: litellm.set_verbose = True - response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5) + response = await acompletion( + model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5 + ) print(type(response)) import inspect @@ -116,29 +153,39 @@ def test_get_response_streaming(): async for chunk in response: token = chunk["choices"][0]["delta"].get("content", "") if token == None: - continue # openai v1.0.0 returns content=None + continue # openai v1.0.0 returns content=None output += token assert output is not None, "output cannot be None." assert isinstance(output, str), "output needs to be of type str" assert len(output) > 0, "Length of output needs to be greater than 0." - print(f'output: {output}') - except litellm.Timeout as e: + print(f"output: {output}") + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") + asyncio.run(test_async_call()) + # test_get_response_streaming() + def test_get_response_non_openai_streaming(): import asyncio + litellm.set_verbose = True litellm.num_retries = 0 + async def test_async_call(): user_message = "Hello, how are you?" 
messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5) + response = await acompletion( + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", + messages=messages, + stream=True, + timeout=5, + ) print(type(response)) import inspect @@ -158,11 +205,13 @@ def test_get_response_non_openai_streaming(): assert output is not None, "output cannot be None." assert isinstance(output, str), "output needs to be of type str" assert len(output) > 0, "Length of output needs to be greater than 0." - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") return response + asyncio.run(test_async_call()) -# test_get_response_non_openai_streaming() \ No newline at end of file + +# test_get_response_non_openai_streaming() diff --git a/litellm/tests/test_azure_perf.py b/litellm/tests/test_azure_perf.py index 36bfe1d80..9654f1273 100644 --- a/litellm/tests/test_azure_perf.py +++ b/litellm/tests/test_azure_perf.py @@ -3,79 +3,105 @@ import sys, os, time, inspect, asyncio, traceback from datetime import datetime import pytest -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) import openai, litellm, uuid from openai import AsyncAzureOpenAI client = AsyncAzureOpenAI( api_key=os.getenv("AZURE_API_KEY"), - azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore - api_version=os.getenv("AZURE_API_VERSION") + azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore + api_version=os.getenv("AZURE_API_VERSION"), ) model_list = [ - { - "model_name": "azure-test", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION") + { + "model_name": "azure-test", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, } - } ] router = litellm.Router(model_list=model_list) + async def _openai_completion(): - try: - start_time = time.time() - response = await client.chat.completions.create( - model="chatgpt-v-2", - messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], - stream=True - ) - time_to_first_token = None - first_token_ts = None - init_chunk = None - async for chunk in response: - if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: - first_token_ts = time.time() - time_to_first_token = first_token_ts - start_time - init_chunk = chunk - end_time = time.time() - print("OpenAI Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time) - return time_to_first_token - except Exception as e: - print(e) - return None + try: + start_time = time.time() + response = await client.chat.completions.create( + model="chatgpt-v-2", + messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], + stream=True, + ) + time_to_first_token = None + first_token_ts = None + init_chunk = None + async for chunk in response: + if ( + time_to_first_token is None + and len(chunk.choices) > 0 + and chunk.choices[0].delta.content is not None + ): + first_token_ts = time.time() + time_to_first_token = first_token_ts - start_time + init_chunk = chunk + end_time = time.time() + print( + "OpenAI Call: ", + init_chunk, + start_time, + 
first_token_ts, + time_to_first_token, + end_time, + ) + return time_to_first_token + except Exception as e: + print(e) + return None + async def _router_completion(): - try: - start_time = time.time() - response = await router.acompletion( - model="azure-test", - messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], - stream=True - ) - time_to_first_token = None - first_token_ts = None - init_chunk = None - async for chunk in response: - if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: - first_token_ts = time.time() - time_to_first_token = first_token_ts - start_time - init_chunk = chunk - end_time = time.time() - print("Router Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time - first_token_ts) - return time_to_first_token - except Exception as e: - print(e) - return None + try: + start_time = time.time() + response = await router.acompletion( + model="azure-test", + messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], + stream=True, + ) + time_to_first_token = None + first_token_ts = None + init_chunk = None + async for chunk in response: + if ( + time_to_first_token is None + and len(chunk.choices) > 0 + and chunk.choices[0].delta.content is not None + ): + first_token_ts = time.time() + time_to_first_token = first_token_ts - start_time + init_chunk = chunk + end_time = time.time() + print( + "Router Call: ", + init_chunk, + start_time, + first_token_ts, + time_to_first_token, + end_time - first_token_ts, + ) + return time_to_first_token + except Exception as e: + print(e) + return None -async def test_azure_completion_streaming(): + +async def test_azure_completion_streaming(): """ - Test azure streaming call - measure on time to first (non-null) token. + Test azure streaming call - measure on time to first (non-null) token. """ n = 3 # Number of concurrent tasks ## OPENAI AVG. TIME @@ -83,19 +109,20 @@ async def test_azure_completion_streaming(): chat_completions = await asyncio.gather(*tasks) successful_completions = [c for c in chat_completions if c is not None] total_time = 0 - for item in successful_completions: - total_time += item - avg_openai_time = total_time/3 + for item in successful_completions: + total_time += item + avg_openai_time = total_time / 3 ## ROUTER AVG. TIME tasks = [_router_completion() for _ in range(n)] chat_completions = await asyncio.gather(*tasks) successful_completions = [c for c in chat_completions if c is not None] total_time = 0 - for item in successful_completions: - total_time += item - avg_router_time = total_time/3 + for item in successful_completions: + total_time += item + avg_router_time = total_time / 3 ## COMPARE print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}") assert avg_router_time < avg_openai_time + 0.5 -# asyncio.run(test_azure_completion_streaming()) \ No newline at end of file + +# asyncio.run(test_azure_completion_streaming()) diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py index 749391bd9..58fe204bd 100644 --- a/litellm/tests/test_bad_params.py +++ b/litellm/tests/test_bad_params.py @@ -5,6 +5,7 @@ import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -18,6 +19,7 @@ user_message = "Hello, how are you?" 
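The averages above divide total_time by a hardcoded 3, even though failed calls are filtered out of successful_completions first. Purely as a sketch, averaging over however many calls actually succeeded avoids that skew:

def average_ttft(samples):
    # samples: time-to-first-token values from the calls that succeeded
    # (None results already filtered out).
    return sum(samples) / len(samples) if samples else float("inf")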
messages = [{"content": user_message, "role": "user"}] model_val = None + def test_completion_with_no_model(): # test on empty with pytest.raises(ValueError): @@ -32,9 +34,10 @@ def test_completion_with_empty_model(): print(f"error occurred: {e}") pass + # def test_completion_catch_nlp_exception(): # TEMP commented out NLP cloud API is unstable -# try: +# try: # response = completion(model="dolphin", messages=messages, functions=[ # { # "name": "get_current_weather", @@ -56,65 +59,77 @@ def test_completion_with_empty_model(): # } # ]) -# except Exception as e: -# if "Function calling is not supported by nlp_cloud" in str(e): +# except Exception as e: +# if "Function calling is not supported by nlp_cloud" in str(e): # pass # else: # pytest.fail(f'An error occurred {e}') -# test_completion_catch_nlp_exception() +# test_completion_catch_nlp_exception() + def test_completion_invalid_param_cohere(): - try: + try: response = completion(model="command-nightly", messages=messages, top_p=1) print(f"response: {response}") - except Exception as e: - if "Unsupported parameters passed: top_p" in str(e): + except Exception as e: + if "Unsupported parameters passed: top_p" in str(e): pass - else: - pytest.fail(f'An error occurred {e}') + else: + pytest.fail(f"An error occurred {e}") + # test_completion_invalid_param_cohere() + def test_completion_function_call_cohere(): - try: - response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]) - pytest.fail(f'An error occurred {e}') - except Exception as e: + try: + response = completion( + model="command-nightly", messages=messages, functions=["TEST-FUNCTION"] + ) + pytest.fail(f"An error occurred {e}") + except Exception as e: print(e) pass - + # test_completion_function_call_cohere() -def test_completion_function_call_openai(): - try: + +def test_completion_function_call_openai(): + try: messages = [{"role": "user", "content": "What is the weather like in Boston?"}] - response = completion(model="gpt-3.5-turbo", messages=messages, functions=[ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] + response = completion( + model="gpt-3.5-turbo", + messages=messages, + functions=[ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, } - }, - "required": ["location"] - } - } - ]) + ], + ) print(f"response: {response}") - except: + except: pass -# test_completion_function_call_openai() + +# test_completion_function_call_openai() + def test_completion_with_no_provider(): # test on empty @@ -125,6 +140,7 @@ def test_completion_with_no_provider(): print(f"error occurred: {e}") pass + # test_completion_with_no_provider() # # bad key # temp_key = os.environ.get("OPENAI_API_KEY") @@ -136,4 +152,4 @@ def test_completion_with_no_provider(): # except: # print(f"error occurred: {traceback.format_exc()}") # pass -# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5 \ No newline at end of file +# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5 diff --git a/litellm/tests/test_batch_completions.py b/litellm/tests/test_batch_completions.py index 8f149f3fa..55e3084b4 100644 --- a/litellm/tests/test_batch_completions.py +++ b/litellm/tests/test_batch_completions.py @@ -4,62 +4,78 @@ import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from openai import APITimeoutError as Timeout import litellm + litellm.num_retries = 0 -from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses +from litellm import ( + batch_completion, + batch_completion_models, + completion, + batch_completion_models_all_responses, +) + # litellm.set_verbose=True + def test_batch_completions(): messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)] model = "j2-mid" litellm.set_verbose = True try: result = batch_completion( - model=model, + model=model, messages=messages, max_tokens=10, temperature=0.2, - request_timeout=1 + request_timeout=1, ) print(result) print(len(result)) - assert(len(result)==3) + assert len(result) == 3 except Timeout as e: print(f"IN TIMEOUT") pass except Exception as e: pytest.fail(f"An error occurred: {e}") + + test_batch_completions() + def test_batch_completions_models(): try: result = batch_completion_models( - models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"], - messages=[{"role": "user", "content": "Hey, how's it going"}] + models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"], + messages=[{"role": "user", "content": "Hey, how's it going"}], ) print(result) except Timeout as e: pass except Exception as e: pytest.fail(f"An error occurred: {e}") + + # test_batch_completions_models() + def test_batch_completion_models_all_responses(): try: responses = batch_completion_models_all_responses( - models=["j2-light", "claude-instant-1.2"], + models=["j2-light", "claude-instant-1.2"], messages=[{"role": "user", "content": "write a poem"}], - max_tokens=10 + max_tokens=10, ) print(responses) - assert(len(responses) == 2) + assert len(responses) == 2 except Timeout as e: pass except Exception as e: pytest.fail(f"An error occurred: {e}") -# test_batch_completion_models_all_responses() + +# test_batch_completion_models_all_responses() diff --git a/litellm/tests/test_budget_manager.py b/litellm/tests/test_budget_manager.py index 5b4f4e6b3..83c9f2723 100644 --- a/litellm/tests/test_budget_manager.py +++ b/litellm/tests/test_budget_manager.py @@ -3,12 +3,12 @@ # import sys, os, json # import traceback -# import pytest +# import pytest # sys.path.insert( # 0, os.path.abspath("../..") # ) # Adds the 
parent directory to the system path -# import litellm +# import litellm # litellm.set_verbose = True # from litellm import completion, BudgetManager @@ -16,7 +16,7 @@ # ## Scenario 1: User budget enough to make call # def test_user_budget_enough(): -# try: +# try: # user = "1234" # # create a budget for a user # budget_manager.create_budget(total_budget=10, user=user, duration="daily") @@ -38,7 +38,7 @@ # ## Scenario 2: User budget not enough to make call # def test_user_budget_not_enough(): -# try: +# try: # user = "12345" # # create a budget for a user # budget_manager.create_budget(total_budget=0, user=user, duration="daily") @@ -60,7 +60,7 @@ # except: # pytest.fail(f"An error occurred") -# ## Scenario 3: Saving budget to client +# ## Scenario 3: Saving budget to client # def test_save_user_budget(): # try: # response = budget_manager.save_data() @@ -70,17 +70,17 @@ # except Exception as e: # pytest.fail(f"An error occurred: {str(e)}") -# test_save_user_budget() -# ## Scenario 4: Getting list of users +# test_save_user_budget() +# ## Scenario 4: Getting list of users # def test_get_users(): # try: # response = budget_manager.get_users() # print(response) # except: -# pytest.fail(f"An error occurred") +# pytest.fail(f"An error occurred") -# ## Scenario 5: Reset budget at the end of duration +# ## Scenario 5: Reset budget at the end of duration # def test_reset_on_duration(): # try: # # First, set a short duration budget for a user @@ -100,7 +100,7 @@ # # Now, we need to simulate the passing of time. Since we don't want our tests to actually take days, we're going # # to cheat a little -- we'll manually adjust the "created_at" time so it seems like a day has passed. -# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time. +# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time. # one_day_in_seconds = 24 * 60 * 60 # budget_manager.user_dict[user]["last_updated_at"] -= one_day_in_seconds @@ -108,11 +108,11 @@ # budget_manager.update_budget_all_users() # # Make sure the budget was actually reset -# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired" +# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired" # except Exception as e: # pytest.fail(f"An error occurred - {str(e)}") -# ## Scenario 6: passing in text: +# ## Scenario 6: passing in text: # def test_input_text_on_completion(): # try: # user = "12345" @@ -127,4 +127,4 @@ # except Exception as e: # pytest.fail(f"An error occurred - {str(e)}") -# test_input_text_on_completion() \ No newline at end of file +# test_input_text_on_completion() diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 24b7f37a8..081f71ebb 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -14,6 +14,7 @@ import litellm from litellm import embedding, completion from litellm.caching import Cache import random + # litellm.set_verbose=True messages = [{"role": "user", "content": "who is ishaan Github? "}] @@ -22,23 +23,30 @@ messages = [{"role": "user", "content": "who is ishaan Github? 
"}] import random import string + def generate_random_word(length=4): letters = string.ascii_lowercase - return ''.join(random.choice(letters) for _ in range(length)) + return "".join(random.choice(letters) for _ in range(length)) + messages = [{"role": "user", "content": "who is ishaan 5222"}] -def test_caching_v2(): # test in memory cache + + +def test_caching_v2(): # test in memory cache try: - litellm.set_verbose=True + litellm.set_verbose = True litellm.cache = Cache() response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) print(f"response1: {response1}") print(f"response2: {response2}") - litellm.cache = None # disable cache + litellm.cache = None # disable cache litellm.success_callback = [] litellm._async_success_callback = [] - if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: + if ( + response2["choices"][0]["message"]["content"] + != response1["choices"][0]["message"]["content"] + ): print(f"response1: {response1}") print(f"response2: {response2}") pytest.fail(f"Error occurred:") @@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") + # test_caching_v2() - def test_caching_with_models_v2(): - messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}] + messages = [ + {"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"} + ] litellm.cache = Cache() print("test2 for caching") response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) @@ -63,34 +73,51 @@ def test_caching_with_models_v2(): litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] - if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: + if ( + response3["choices"][0]["message"]["content"] + == response2["choices"][0]["message"]["content"] + ): # if models are different, it should not return cached response print(f"response2: {response2}") print(f"response3: {response3}") pytest.fail(f"Error occurred:") - if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: + if ( + response1["choices"][0]["message"]["content"] + != response2["choices"][0]["message"]["content"] + ): print(f"response1: {response1}") print(f"response2: {response2}") pytest.fail(f"Error occurred:") + + # test_caching_with_models_v2() -embedding_large_text = """ +embedding_large_text = ( + """ small text -""" * 5 +""" + * 5 +) + # # test_caching_with_models() def test_embedding_caching(): import time + litellm.cache = Cache() text_to_embed = [embedding_large_text] start_time = time.time() - embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) + embedding1 = embedding( + model="text-embedding-ada-002", input=text_to_embed, caching=True + ) end_time = time.time() print(f"Embedding 1 response time: {end_time - start_time} seconds") time.sleep(1) start_time = time.time() - embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) + embedding2 = embedding( + model="text-embedding-ada-002", input=text_to_embed, caching=True + ) end_time = time.time() print(f"embedding2: {embedding2}") print(f"Embedding 2 response time: {end_time - start_time} seconds") @@ -98,29 +125,30 @@ def test_embedding_caching(): litellm.cache = None litellm.success_callback = [] 
litellm._async_success_callback = [] - assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s - if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: + assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s + if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]: print(f"embedding1: {embedding1}") print(f"embedding2: {embedding2}") pytest.fail("Error occurred: Embedding caching failed") + # test_embedding_caching() def test_embedding_caching_azure(): print("Testing azure embedding caching") import time + litellm.cache = Cache() text_to_embed = [embedding_large_text] - api_key = os.environ['AZURE_API_KEY'] - api_base = os.environ['AZURE_API_BASE'] - api_version = os.environ['AZURE_API_VERSION'] - - os.environ['AZURE_API_VERSION'] = "" - os.environ['AZURE_API_BASE'] = "" - os.environ['AZURE_API_KEY'] = "" + api_key = os.environ["AZURE_API_KEY"] + api_base = os.environ["AZURE_API_BASE"] + api_version = os.environ["AZURE_API_VERSION"] + os.environ["AZURE_API_VERSION"] = "" + os.environ["AZURE_API_BASE"] = "" + os.environ["AZURE_API_KEY"] = "" start_time = time.time() print("AZURE CONFIGS") @@ -133,7 +161,7 @@ def test_embedding_caching_azure(): api_key=api_key, api_base=api_base, api_version=api_version, - caching=True + caching=True, ) end_time = time.time() print(f"Embedding 1 response time: {end_time - start_time} seconds") @@ -146,7 +174,7 @@ def test_embedding_caching_azure(): api_key=api_key, api_base=api_base, api_version=api_version, - caching=True + caching=True, ) end_time = time.time() print(f"Embedding 2 response time: {end_time - start_time} seconds") @@ -154,15 +182,16 @@ def test_embedding_caching_azure(): litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] - assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s - if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: + assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s + if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]: print(f"embedding1: {embedding1}") print(f"embedding2: {embedding2}") pytest.fail("Error occurred: Embedding caching failed") - os.environ['AZURE_API_VERSION'] = api_version - os.environ['AZURE_API_BASE'] = api_base - os.environ['AZURE_API_KEY'] = api_key + os.environ["AZURE_API_VERSION"] = api_version + os.environ["AZURE_API_BASE"] = api_base + os.environ["AZURE_API_KEY"] = api_key + # test_embedding_caching_azure() @@ -170,13 +199,28 @@ def test_embedding_caching_azure(): def test_redis_cache_completion(): litellm.set_verbose = False - random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache - messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) print("test2 for caching") - response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20) - 
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20) - response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5) + response1 = completion( + model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20 + ) + response2 = completion( + model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20 + ) + response3 = completion( + model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5 + ) response4 = completion(model="command-nightly", messages=messages, caching=True) print("\nresponse 1", response1) @@ -192,49 +236,88 @@ def test_redis_cache_completion(): 1 & 3 should be different, since input params are diff 1 & 4 should be diff, since models are diff """ - if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same + if ( + response1["choices"][0]["message"]["content"] + != response2["choices"][0]["message"]["content"] + ): # 1 and 2 should be the same # 1&2 have the exact same input params. This MUST Be a CACHE HIT print(f"response1: {response1}") print(f"response2: {response2}") pytest.fail(f"Error occurred:") - if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']: + if ( + response1["choices"][0]["message"]["content"] + == response3["choices"][0]["message"]["content"] + ): # if input params like seed, max_tokens are diff it should NOT be a cache hit print(f"response1: {response1}") print(f"response3: {response3}") - pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:") - if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']: + pytest.fail( + f"Response 1 == response 3. 
Same model, diff params shoudl not cache Error occurred:" + ) + if ( + response1["choices"][0]["message"]["content"] + == response4["choices"][0]["message"]["content"] + ): # if models are different, it should not return cached response print(f"response1: {response1}") print(f"response4: {response4}") pytest.fail(f"Error occurred:") + # test_redis_cache_completion() + def test_redis_cache_completion_stream(): try: litellm.success_callback = [] litellm._async_success_callback = [] litellm.callbacks = [] litellm.set_verbose = True - random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache - messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + { + "role": "user", + "content": f"write a one sentence poem about: {random_number}", + } + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) print("test for caching, streaming + completion") - response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=0.2, + stream=True, + ) response_1_content = "" for chunk in response1: print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) time.sleep(0.5) - response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) + response2 = completion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=0.2, + stream=True, + ) response_2_content = "" for chunk in response2: print(chunk) response_2_content += chunk.choices[0].delta.content or "" print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) - assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + assert ( + response_1_content == response_2_content + ), f"Response 1 != Response 2. 
Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.success_callback = [] litellm.cache = None litellm.success_callback = [] @@ -247,99 +330,171 @@ def test_redis_cache_completion_stream(): 1 & 2 should be exactly the same """ + + # test_redis_cache_completion_stream() def test_redis_cache_acompletion_stream(): import asyncio + try: litellm.set_verbose = True random_word = generate_random_word() - messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + messages = [ + { + "role": "user", + "content": f"write a one sentence poem about: {random_word}", + } + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) print("test for caching, streaming + completion") response_1_content = "" response_2_content = "" async def call1(): - nonlocal response_1_content - response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) + nonlocal response_1_content + response1 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) async for chunk in response1: print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) + asyncio.run(call1()) time.sleep(0.5) print("\n\n Response 1 content: ", response_1_content, "\n\n") async def call2(): nonlocal response_2_content - response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) + response2 = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) async for chunk in response2: print(chunk) response_2_content += chunk.choices[0].delta.content or "" print(response_2_content) + asyncio.run(call2()) print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) - assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + assert ( + response_1_content == response_2_content + ), f"Response 1 != Response 2. 
Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] except Exception as e: print(e) raise e + + # test_redis_cache_acompletion_stream() + def test_redis_cache_acompletion_stream_bedrock(): import asyncio + try: litellm.set_verbose = True random_word = generate_random_word() - messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + messages = [ + { + "role": "user", + "content": f"write a one sentence poem about: {random_word}", + } + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) print("test for caching, streaming + completion") response_1_content = "" response_2_content = "" async def call1(): - nonlocal response_1_content - response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) + nonlocal response_1_content + response1 = await litellm.acompletion( + model="bedrock/anthropic.claude-v1", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) async for chunk in response1: print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) + asyncio.run(call1()) time.sleep(0.5) print("\n\n Response 1 content: ", response_1_content, "\n\n") async def call2(): nonlocal response_2_content - response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) + response2 = await litellm.acompletion( + model="bedrock/anthropic.claude-v1", + messages=messages, + max_tokens=40, + temperature=1, + stream=True, + ) async for chunk in response2: print(chunk) response_2_content += chunk.choices[0].delta.content or "" print(response_2_content) + asyncio.run(call2()) print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) - assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + assert ( + response_1_content == response_2_content + ), f"Response 1 != Response 2. 
Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] except Exception as e: print(e) raise e + + # test_redis_cache_acompletion_stream_bedrock() # redis cache with custom keys def custom_get_cache_key(*args, **kwargs): - # return key to use for your cache: - key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", "")) + # return key to use for your cache: + key = ( + kwargs.get("model", "") + + str(kwargs.get("messages", "")) + + str(kwargs.get("temperature", "")) + + str(kwargs.get("logit_bias", "")) + ) return key + def test_custom_redis_cache_with_key(): messages = [{"role": "user", "content": "write a one line story"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) litellm.cache.get_cache_key = custom_get_cache_key local_cache = {} @@ -356,54 +511,72 @@ def test_custom_redis_cache_with_key(): # patch this redis cache get and set call - response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3) - response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3) - response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + temperature=1, + caching=True, + num_retries=3, + ) + response2 = completion( + model="gpt-3.5-turbo", + messages=messages, + temperature=1, + caching=True, + num_retries=3, + ) + response3 = completion( + model="gpt-3.5-turbo", + messages=messages, + temperature=1, + caching=False, + num_retries=3, + ) print(f"response1: {response1}") print(f"response2: {response2}") print(f"response3: {response3}") - if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: + if ( + response3["choices"][0]["message"]["content"] + == response2["choices"][0]["message"]["content"] + ): pytest.fail(f"Error occurred:") litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] + # test_custom_redis_cache_with_key() + def test_cache_override(): # test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set - # in this case it should not return cached responses + # in this case it should not return cached responses litellm.cache = Cache() print("Testing cache override") - litellm.set_verbose=True + litellm.set_verbose = True # test embedding response1 = embedding( - model = "text-embedding-ada-002", - input=[ - "hello who are you" - ], - caching = False + model="text-embedding-ada-002", input=["hello who are you"], caching=False ) - start_time = time.time() response2 = embedding( - model = "text-embedding-ada-002", - input=[ - "hello who are you" - ], - caching = False + model="text-embedding-ada-002", input=["hello who are you"], caching=False ) end_time = time.time() print(f"Embedding 2 response time: {end_time - start_time} seconds") - assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached. -# test_cache_override() + assert ( + end_time - start_time > 0.1 + ) # ensure 2nd response comes in over 0.1s. 
This should not be cached. + + +# test_cache_override() def test_custom_redis_cache_params(): @@ -411,17 +584,17 @@ def test_custom_redis_cache_params(): try: litellm.cache = Cache( type="redis", - host=os.environ['REDIS_HOST'], - port=os.environ['REDIS_PORT'], - password=os.environ['REDIS_PASSWORD'], - db = 0, + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + db=0, ssl=True, ssl_certfile="./redis_user.crt", ssl_keyfile="./redis_user_private.key", ssl_ca_certs="./redis_ca.pem", ) - print(litellm.cache.cache.redis_client) + print(litellm.cache.cache.redis_client) litellm.cache = None litellm.success_callback = [] litellm._async_success_callback = [] @@ -431,58 +604,126 @@ def test_custom_redis_cache_params(): def test_get_cache_key(): from litellm.caching import Cache + try: print("Testing get_cache_key") cache_instance = Cache() - cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}} + cache_key = cache_instance.get_cache_key( + **{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "write a one sentence poem about: 7510"} + ], + "max_tokens": 40, + "temperature": 0.2, + "stream": True, + "litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b", + "litellm_logging_obj": {}, + } ) - cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}} + cache_key_2 = cache_instance.get_cache_key( + **{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "write a one sentence poem about: 7510"} + ], + "max_tokens": 40, + "temperature": 0.2, + "stream": True, + "litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b", + "litellm_logging_obj": {}, + } ) - assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40" - assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs" + assert ( + cache_key + == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40" + ) + assert ( + cache_key == cache_key_2 + ), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs" embedding_cache_key = cache_instance.get_cache_key( - **{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', - 'api_key': '', 'api_version': '2023-07-01-preview', - 'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], - 'caching': True, - 'client': "" + **{ + "model": "azure/azure-embedding-model", + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "api_key": "", + "api_version": "2023-07-01-preview", + "timeout": None, + "max_retries": 0, + "input": ["hi who is ishaan"], + "caching": True, + "client": "", } ) print(embedding_cache_key) - assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. 
The same kwargs should have the same cache key across runs" + assert ( + embedding_cache_key + == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']" + ), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs" # Proxy - embedding cache, test if embedding key, gets model_group and not model embedding_cache_key_2 = cache_instance.get_cache_key( - **{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', - 'api_key': '', 'api_version': '2023-07-01-preview', - 'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], - 'caching': True, - 'client': "", - 'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings', - 'method': 'POST', - 'headers': - {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', - 'content-length': '80'}, - 'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}}, - 'user': None, - 'metadata': {'user_api_key': None, - 'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'}, - 'model_group': 'EMBEDDING_MODEL_GROUP', - 'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'}, - 'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'}, - 'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': ''} + **{ + "model": "azure/azure-embedding-model", + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "api_key": "", + "api_version": "2023-07-01-preview", + "timeout": None, + "max_retries": 0, + "input": ["hi who is ishaan"], + "caching": True, + "client": "", + "proxy_server_request": { + "url": "http://0.0.0.0:8000/embeddings", + "method": "POST", + "headers": { + "host": "0.0.0.0:8000", + "user-agent": "curl/7.88.1", + "accept": "*/*", + "content-type": "application/json", + "content-length": "80", + }, + "body": { + "model": "azure-embedding-model", + "input": ["hi who is ishaan"], + }, + }, + "user": None, + "metadata": { + "user_api_key": None, + "headers": { + "host": "0.0.0.0:8000", + "user-agent": "curl/7.88.1", + "accept": "*/*", + "content-type": "application/json", + "content-length": "80", + }, + "model_group": "EMBEDDING_MODEL_GROUP", + "deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview", + }, + "model_info": { + "mode": "embedding", + "base_model": "text-embedding-ada-002", + "id": "20b2b515-f151-4dd5-a74f-2231e2f54e29", + }, + "litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e", + "litellm_logging_obj": "", + } ) print(embedding_cache_key_2) - assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']" + assert ( + embedding_cache_key_2 + == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']" + ) print("passed!") except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred:", e) + test_get_cache_key() # test_custom_redis_cache_params() @@ -581,4 +822,4 @@ test_get_cache_key() # assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] # time.sleep(2) # assert cache.get_cache(cache_key="test_key") is None -# # test_in_memory_cache_with_ttl() \ No 
newline at end of file +# # test_in_memory_cache_with_ttl() diff --git a/litellm/tests/test_caching_ssl.py b/litellm/tests/test_caching_ssl.py index 839e37ea4..84ece8310 100644 --- a/litellm/tests/test_caching_ssl.py +++ b/litellm/tests/test_caching_ssl.py @@ -1,5 +1,5 @@ #### What this tests #### -# This tests using caching w/ litellm which requires SSL=True +# This tests using caching w/ litellm which requires SSL=True import sys, os import time @@ -18,15 +18,26 @@ from litellm import embedding, completion, Router from litellm.caching import Cache messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}] -def test_caching_v2(): # test in memory cache + + +def test_caching_v2(): # test in memory cache try: - litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") + litellm.cache = Cache( + type="redis", + host="os.environ/REDIS_HOST_2", + port="os.environ/REDIS_PORT_2", + password="os.environ/REDIS_PASSWORD_2", + ssl="os.environ/REDIS_SSL_2", + ) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) print(f"response1: {response1}") print(f"response2: {response2}") - litellm.cache = None # disable cache - if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: + litellm.cache = None # disable cache + if ( + response2["choices"][0]["message"]["content"] + != response1["choices"][0]["message"]["content"] + ): print(f"response1: {response1}") print(f"response2: {response2}") raise Exception() @@ -34,41 +45,57 @@ def test_caching_v2(): # test in memory cache print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") + # test_caching_v2() def test_caching_router(): """ - Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit. + Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit. 
""" - try: + try: model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - } - ] - litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") - router = Router(model_list=model_list, - routing_strategy="simple-shuffle", - set_verbose=False, - num_retries=1) # type: ignore + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + } + ] + litellm.cache = Cache( + type="redis", + host="os.environ/REDIS_HOST_2", + port="os.environ/REDIS_PORT_2", + password="os.environ/REDIS_PASSWORD_2", + ssl="os.environ/REDIS_SSL_2", + ) + router = Router( + model_list=model_list, + routing_strategy="simple-shuffle", + set_verbose=False, + num_retries=1, + ) # type: ignore response1 = completion(model="gpt-3.5-turbo", messages=messages) response2 = completion(model="gpt-3.5-turbo", messages=messages) - if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: + if ( + response2["choices"][0]["message"]["content"] + != response1["choices"][0]["message"]["content"] + ): print(f"response1: {response1}") print(f"response2: {response2}") - litellm.cache = None # disable cache - assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content'] + litellm.cache = None # disable cache + assert ( + response2["choices"][0]["message"]["content"] + == response1["choices"][0]["message"]["content"] + ) except Exception as e: print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") -# test_caching_router() \ No newline at end of file + +# test_caching_router() diff --git a/litellm/tests/test_class.py b/litellm/tests/test_class.py index aa7f86242..3520d870d 100644 --- a/litellm/tests/test_class.py +++ b/litellm/tests/test_class.py @@ -8,7 +8,7 @@ # 0, os.path.abspath("../..") # ) # Adds the parent directory to the system path # import litellm -# import asyncio +# import asyncio # litellm.set_verbose = True # from litellm import Router @@ -18,9 +18,9 @@ # # This enables response_model keyword # # # from client.chat.completions.create # # client = instructor.patch(Router(model_list=[{ -# # "model_name": "gpt-3.5-turbo", # openai model name -# # "litellm_params": { # params for litellm completion/embedding call -# # "model": "azure/chatgpt-v-2", +# # "model_name": "gpt-3.5-turbo", # openai model name +# # "litellm_params": { # params for litellm completion/embedding call +# # "model": "azure/chatgpt-v-2", # # "api_key": os.getenv("AZURE_API_KEY"), # # "api_version": os.getenv("AZURE_API_VERSION"), # # "api_base": os.getenv("AZURE_API_BASE") @@ -49,9 +49,9 @@ # from openai import AsyncOpenAI # aclient = instructor.apatch(Router(model_list=[{ -# "model_name": "gpt-3.5-turbo", # openai model name -# "litellm_params": { # params for litellm completion/embedding call -# "model": "azure/chatgpt-v-2", +# "model_name": "gpt-3.5-turbo", # openai model name +# "litellm_params": 
{ # params for litellm completion/embedding call +# "model": "azure/chatgpt-v-2", # "api_key": os.getenv("AZURE_API_KEY"), # "api_version": os.getenv("AZURE_API_VERSION"), # "api_base": os.getenv("AZURE_API_BASE") @@ -71,4 +71,4 @@ # ) # print(f"model: {model}") -# asyncio.run(main()) \ No newline at end of file +# asyncio.run(main()) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 2f4195717..4408abb53 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -7,20 +7,23 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError + # litellm.num_retries = 3 litellm.cache = None -litellm.success_callback = [] +litellm.success_callback = [] user_message = "Write a short poem about the sky" messages = [{"content": user_message, "role": "user"}] + def logger_fn(user_model_dict): print(f"user_model_dict: {user_model_dict}") + @pytest.fixture(autouse=True) def reset_callbacks(): print("\npytest fixture - resetting callbacks") @@ -29,6 +32,7 @@ def reset_callbacks(): litellm.failure_callback = [] litellm.callbacks = [] + def test_completion_custom_provider_model_name(): try: litellm.cache = None @@ -39,7 +43,7 @@ def test_completion_custom_provider_model_name(): ) # Add any assertions here to check the response print(response) - print(response['choices'][0]['finish_reason']) + print(response["choices"][0]["finish_reason"]) except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -51,52 +55,19 @@ def test_completion_claude(): litellm.set_verbose = True litellm.cache = None litellm.AnthropicConfig(max_tokens_to_sample=200, metadata={"user_id": "1224"}) - messages = [{"role": "system", "content": """You are an upbeat, enthusiastic personal fitness coach named Sam. Sam is passionate about helping clients get fit and lead healthier lifestyles. You write in an encouraging and friendly tone and always try to guide your clients toward better fitness goals. If the user asks you something unrelated to fitness, either bring the topic back to fitness, or say that you cannot answer."""},{"content": user_message, "role": "user"}] + messages = [ + { + "role": "system", + "content": """You are an upbeat, enthusiastic personal fitness coach named Sam. Sam is passionate about helping clients get fit and lead healthier lifestyles. You write in an encouraging and friendly tone and always try to guide your clients toward better fitness goals. 
If the user asks you something unrelated to fitness, either bring the topic back to fitness, or say that you cannot answer.""", + }, + {"content": user_message, "role": "user"}, + ] try: # test without max tokens response = completion( - model="claude-instant-1", messages=messages, request_timeout=10, - ) - # Add any assertions here to check the response - print(response) - print(response.usage) - print(response.usage.completion_tokens) - print(response["usage"]["completion_tokens"]) - # print("new cost tracking") - except Exception as e: - pytest.fail(f"Error occurred: {e}") - -# test_completion_claude() - -def test_completion_mistral_api(): - try: - litellm.set_verbose=True - response = completion( - model="mistral/mistral-tiny", - messages=[ - { - "role": "user", - "content": "Hey, how's it going?", - } - ], - safe_mode = True - ) - # Add any assertions here to check the response - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") -# test_completion_mistral_api() - -def test_completion_claude2_1(): - try: - print("claude2.1 test request") - messages=[{'role': 'system', 'content': 'Your goal is generate a joke on the topic user gives'}, {'role': 'assistant', 'content': 'Hi, how can i assist you today?'}, {'role': 'user', 'content': 'Generate a 3 liner joke for me'}] - # test without max tokens - response = completion( - model="claude-2.1", - messages=messages, + model="claude-instant-1", + messages=messages, request_timeout=10, - max_tokens=10 ) # Add any assertions here to check the response print(response) @@ -106,6 +77,58 @@ def test_completion_claude2_1(): # print("new cost tracking") except Exception as e: pytest.fail(f"Error occurred: {e}") + + +# test_completion_claude() + + +def test_completion_mistral_api(): + try: + litellm.set_verbose = True + response = completion( + model="mistral/mistral-tiny", + messages=[ + { + "role": "user", + "content": "Hey, how's it going?", + } + ], + safe_mode=True, + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +# test_completion_mistral_api() + + +def test_completion_claude2_1(): + try: + print("claude2.1 test request") + messages = [ + { + "role": "system", + "content": "Your goal is generate a joke on the topic user gives", + }, + {"role": "assistant", "content": "Hi, how can i assist you today?"}, + {"role": "user", "content": "Generate a 3 liner joke for me"}, + ] + # test without max tokens + response = completion( + model="claude-2.1", messages=messages, request_timeout=10, max_tokens=10 + ) + # Add any assertions here to check the response + print(response) + print(response.usage) + print(response.usage.completion_tokens) + print(response["usage"]["completion_tokens"]) + # print("new cost tracking") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_completion_claude2_1() # def test_completion_oobabooga(): @@ -144,10 +167,12 @@ def test_completion_claude2_1(): # test_completion_aleph_alpha_control_models() import openai + + def test_completion_gpt4_turbo(): try: response = completion( - model="gpt-4-1106-preview", + model="gpt-4-1106-preview", messages=messages, max_tokens=10, ) @@ -157,29 +182,29 @@ def test_completion_gpt4_turbo(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_gpt4_turbo() + @pytest.mark.skip(reason="this test is flaky") def test_completion_gpt4_vision(): try: - litellm.set_verbose=True + litellm.set_verbose = True response = completion( - 
model="gpt-4-vision-preview", + model="gpt-4-vision-preview", messages=[ { "role": "user", "content": [ - { - "type": "text", - "text": "Whats in this image?" - }, - { - "type": "image_url", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - } - } - ] + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + }, + }, + ], } ], ) @@ -189,53 +214,60 @@ def test_completion_gpt4_vision(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_gpt4_vision() + @pytest.mark.skip(reason="this test is flaky") def test_completion_perplexity_api(): try: # litellm.set_verbose=True - messages=[{ - "role": "system", - "content": "You're a good bot" - },{ - "role": "user", - "content": "Hey", - },{ - "role": "user", - "content": "Hey", - }] + messages = [ + {"role": "system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + { + "role": "user", + "content": "Hey", + }, + ] response = completion( - model="mistral-7b-instruct", + model="mistral-7b-instruct", messages=messages, - api_base="https://api.perplexity.ai") + api_base="https://api.perplexity.ai", + ) print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_perplexity_api() + @pytest.mark.skip(reason="this test is flaky") def test_completion_perplexity_api_2(): try: # litellm.set_verbose=True - messages=[{ - "role": "system", - "content": "You're a good bot" - },{ - "role": "user", - "content": "Hey", - },{ - "role": "user", - "content": "Hey", - }] - response = completion( - model="perplexity/mistral-7b-instruct", - messages=messages - ) + messages = [ + {"role": "system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + { + "role": "user", + "content": "Hey", + }, + ] + response = completion(model="perplexity/mistral-7b-instruct", messages=messages) print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_perplexity_api_2() # commenting out as this is a flaky test on circle ci @@ -269,6 +301,8 @@ HF Tests we should pass - Free Inference API - Deployed Endpoint """ + + ##################################################### ##################################################### # Test util to sort models to TGI, conv, None @@ -276,28 +310,29 @@ def test_get_hf_task_for_model(): model = "glaiveai/glaive-coder-7b" model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) print(f"model:{model}, model type: {model_type}") - assert(model_type == "text-generation-inference") + assert model_type == "text-generation-inference" model = "meta-llama/Llama-2-7b-hf" model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) print(f"model:{model}, model type: {model_type}") - assert(model_type == "text-generation-inference") + assert model_type == "text-generation-inference" model = "facebook/blenderbot-400M-distill" model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) print(f"model:{model}, model type: {model_type}") - assert(model_type == "conversational") + assert model_type == "conversational" model = "facebook/blenderbot-3B" model_type = 
litellm.llms.huggingface_restapi.get_hf_task_for_model(model) print(f"model:{model}, model type: {model_type}") - assert(model_type == "conversational") + assert model_type == "conversational" # neither Conv or None model = "roneneldan/TinyStories-3M" model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) print(f"model:{model}, model type: {model_type}") - assert(model_type == None) + assert model_type == None + # test_get_hf_task_for_model() # litellm.set_verbose=False @@ -308,13 +343,15 @@ def hf_test_completion_tgi(): # litellm.set_verbose=True try: response = completion( - model = 'huggingface/HuggingFaceH4/zephyr-7b-beta', - messages = [{ "content": "Hello, how are you?","role": "user"}], + model="huggingface/HuggingFaceH4/zephyr-7b-beta", + messages=[{"content": "Hello, how are you?", "role": "user"}], ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # hf_test_completion_tgi() # ################### Hugging Face Conversational models ######################## @@ -337,7 +374,7 @@ def hf_test_completion_tgi(): # user_message = "My name is Merve and my favorite" # messages = [{ "content": user_message,"role": "user"}] # response = completion( -# model="huggingface/roneneldan/TinyStories-3M", +# model="huggingface/roneneldan/TinyStories-3M", # messages=messages, # api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud", # ) @@ -397,7 +434,7 @@ def hf_test_completion_tgi(): # user_message = "My name is Merve and my favorite" # messages = [{ "content": user_message,"role": "user"}] # response = completion( -# model="huggingface/roneneldan/TinyStories-3M", +# model="huggingface/roneneldan/TinyStories-3M", # messages=messages, # api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud", @@ -410,13 +447,12 @@ def hf_test_completion_tgi(): # hf_test_error_logs() -def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky + +def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky try: litellm.CohereConfig(max_tokens=1000, stop_sequences=["a"]) response = completion( - model="command-nightly", - messages=messages, - logger_fn=logger_fn + model="command-nightly", messages=messages, logger_fn=logger_fn ) # Add any assertions here to check the response print(response) @@ -429,24 +465,24 @@ def test_completion_cohere(): # commenting for now as the cohere endpoint is bei except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_cohere() + +# test_completion_cohere() def test_completion_openai(): try: - litellm.set_verbose=True + litellm.set_verbose = True print(f"api key: {os.environ['OPENAI_API_KEY']}") - litellm.api_key = os.environ['OPENAI_API_KEY'] + litellm.api_key = os.environ["OPENAI_API_KEY"] response = completion( - model="gpt-3.5-turbo", - messages=messages, - max_tokens=10, + model="gpt-3.5-turbo", + messages=messages, + max_tokens=10, request_timeout=1, - metadata = {"hi": "bye"} + metadata={"hi": "bye"}, ) print("This is the response object\n", response) - response_str = response["choices"][0]["message"]["content"] response_str_2 = response.choices[0].message.content @@ -461,8 +497,11 @@ def test_completion_openai(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai() + def test_completion_text_openai(): try: # litellm.set_verbose = True @@ -471,12 +510,16 @@ def test_completion_text_openai(): except Exception as e: print(e) 
pytest.fail(f"Error occurred: {e}") + + # test_completion_text_openai() + def custom_callback( - kwargs, # kwargs to completion - completion_response, # response from completion - start_time, end_time # start/end time + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time ): # Your custom code here try: @@ -505,48 +548,49 @@ def custom_callback( except Exception as e: pytest.fail(f"Error occurred: {e}") + def test_completion_openai_with_optional_params(): # [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST # assert that `user` gets passed to the completion call # Note: This tests that we actually send the optional params to the completion call - # We use custom callbacks to test this + # We use custom callbacks to test this try: litellm.set_verbose = True litellm.success_callback = [custom_callback] response = completion( model="gpt-3.5-turbo-1106", messages=[ - { - "role": "user", - "content": "respond in valid, json - what is the day" - } + {"role": "user", "content": "respond in valid, json - what is the day"} ], temperature=0.5, top_p=0.1, seed=12, - response_format={ "type": "json_object" }, + response_format={"type": "json_object"}, logit_bias=None, - user = "ishaans app" + user="ishaans app", ) # Add any assertions here to check the response print(response) - litellm.success_callback = [] # unset callbacks + litellm.success_callback = [] # unset callbacks except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_openai_with_optional_params() + def test_completion_openai_litellm_key(): try: litellm.set_verbose = True litellm.num_retries = 0 - litellm.api_key = os.environ['OPENAI_API_KEY'] + litellm.api_key = os.environ["OPENAI_API_KEY"] # ensure key is set to None in .env and in openai.api_key - os.environ['OPENAI_API_KEY'] = "" + os.environ["OPENAI_API_KEY"] = "" import openai + openai.api_key = "" ########################################################## @@ -562,33 +606,37 @@ def test_completion_openai_litellm_key(): print(response) ###### reset environ key - os.environ['OPENAI_API_KEY'] = litellm.api_key + os.environ["OPENAI_API_KEY"] = litellm.api_key ##### unset litellm var litellm.api_key = None - except Timeout as e: + except Timeout as e: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_openai_litellm_key() -def test_completion_ollama_hosted(): + +def test_completion_ollama_hosted(): try: litellm.set_verbose = True response = completion( model="ollama/phi", messages=messages, max_tokens=10, - api_base="https://test-ollama-endpoint.onrender.com" + api_base="https://test-ollama-endpoint.onrender.com", ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_ollama_hosted() + def test_completion_openrouter1(): try: response = completion( @@ -600,7 +648,10 @@ def test_completion_openrouter1(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_openrouter1() + + +# test_completion_openrouter1() + def test_completion_hf_model_no_provider(): try: @@ -615,8 +666,10 @@ def test_completion_hf_model_no_provider(): except Exception as e: pass + # test_completion_hf_model_no_provider() + def test_completion_anyscale_with_functions(): function1 = [ { @@ -638,24 +691,29 @@ def test_completion_anyscale_with_functions(): try: messages = [{"role": "user", "content": "What is the weather like in Boston?"}] response = completion( - 
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1 + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", + messages=messages, + functions=function1, ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_anyscale_with_functions() + def test_completion_azure_key_completion_arg(): - # this tests if we can pass api_key to completion, when it's not in the env - # DO NOT REMOVE THIS TEST. No MATTER WHAT Happens. - # If you want to remove it, speak to Ishaan! + # this tests if we can pass api_key to completion, when it's not in the env + # DO NOT REMOVE THIS TEST. No MATTER WHAT Happens. + # If you want to remove it, speak to Ishaan! # Ishaan will be very disappointed if this test is removed -> this is a standard way to pass api_key + the router + proxy use this old_key = os.environ["AZURE_API_KEY"] os.environ.pop("AZURE_API_KEY", None) try: print("azure gpt-3.5 test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True ## Test azure call response = completion( model="azure/chatgpt-v-2", @@ -668,30 +726,32 @@ def test_completion_azure_key_completion_arg(): except Exception as e: os.environ["AZURE_API_KEY"] = old_key pytest.fail(f"Error occurred: {e}") + + # test_completion_azure_key_completion_arg() async def test_re_use_azure_async_client(): try: print("azure gpt-3.5 ASYNC with clie nttest\n\n") - litellm.set_verbose=True + litellm.set_verbose = True import openai + client = openai.AsyncAzureOpenAI( - azure_endpoint=os.environ['AZURE_API_BASE'], - api_key=os.environ["AZURE_API_KEY"], - api_version="2023-07-01-preview", + azure_endpoint=os.environ["AZURE_API_BASE"], + api_key=os.environ["AZURE_API_KEY"], + api_version="2023-07-01-preview", ) ## Test azure call for _ in range(3): response = await litellm.acompletion( - model="azure/chatgpt-v-2", - messages=messages, - client=client + model="azure/chatgpt-v-2", messages=messages, client=client ) print(f"response: {response}") except Exception as e: pytest.fail("got Exception", e) + # import asyncio # asyncio.run( # test_re_use_azure_async_client() @@ -701,32 +761,34 @@ async def test_re_use_azure_async_client(): def test_re_use_openaiClient(): try: print("gpt-3.5 with client test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True import openai + client = openai.OpenAI( - api_key=os.environ["OPENAI_API_KEY"], + api_key=os.environ["OPENAI_API_KEY"], ) ## Test OpenAI call for _ in range(2): response = litellm.completion( - model="gpt-3.5-turbo", - messages=messages, - client=client + model="gpt-3.5-turbo", messages=messages, client=client ) print(f"response: {response}") except Exception as e: pytest.fail("got Exception", e) + + # test_re_use_openaiClient() + def test_completion_azure(): try: print("azure gpt-3.5 test\n\n") - litellm.set_verbose=False + litellm.set_verbose = False ## Test azure call response = completion( model="azure/chatgpt-v-2", messages=messages, - api_key="os.environ/AZURE_API_KEY" + api_key="os.environ/AZURE_API_KEY", ) print(f"response: {response}") ## Test azure flag for backwards compatibility @@ -740,24 +802,27 @@ def test_completion_azure(): print(response) cost = completion_cost(completion_response=response) - assert cost > 0.0 + assert cost > 0.0 print("Cost for azure completion request", cost) except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_azure() + +# test_completion_azure() + def test_azure_openai_ad_token(): # this tests if the azure ad 
token is set in the request header # the request can fail since azure ad tokens expire after 30 mins, but the header MUST have the azure ad token # we use litellm.input_callbacks for this test - def tester( - kwargs, # kwargs to completion + def tester( + kwargs, # kwargs to completion ): print(kwargs["additional_args"]) - if kwargs["additional_args"]["headers"]["Authorization"] != 'Bearer gm': + if kwargs["additional_args"]["headers"]["Authorization"] != "Bearer gm": pytest.fail("AZURE AD TOKEN Passed but not set in request header") return + litellm.input_callback = [tester] try: response = litellm.completion( @@ -768,7 +833,7 @@ def test_azure_openai_ad_token(): "content": "what is your name", }, ], - azure_ad_token="gm" + azure_ad_token="gm", ) print("azure ad token respoonse\n") print(response) @@ -776,6 +841,8 @@ def test_azure_openai_ad_token(): except: litellm.input_callback = [] pass + + # test_azure_openai_ad_token() @@ -784,7 +851,7 @@ def test_completion_azure2(): # test if we can pass api_base, api_version and api_key in compleition() try: print("azure gpt-3.5 test\n\n") - litellm.set_verbose=False + litellm.set_verbose = False api_base = os.environ["AZURE_API_BASE"] api_key = os.environ["AZURE_API_KEY"] api_version = os.environ["AZURE_API_VERSION"] @@ -793,14 +860,13 @@ def test_completion_azure2(): os.environ["AZURE_API_VERSION"] = "" os.environ["AZURE_API_KEY"] = "" - ## Test azure call response = completion( model="azure/chatgpt-v-2", messages=messages, - api_base = api_base, - api_key = api_key, - api_version = api_version, + api_base=api_base, + api_key=api_key, + api_version=api_version, max_tokens=10, ) @@ -814,13 +880,15 @@ def test_completion_azure2(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_azure2() + def test_completion_azure3(): # test if we can pass api_base, api_version and api_key in compleition() try: print("azure gpt-3.5 test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True litellm.api_base = os.environ["AZURE_API_BASE"] litellm.api_key = os.environ["AZURE_API_KEY"] litellm.api_version = os.environ["AZURE_API_VERSION"] @@ -829,7 +897,6 @@ def test_completion_azure3(): os.environ["AZURE_API_VERSION"] = "" os.environ["AZURE_API_KEY"] = "" - ## Test azure call response = completion( model="azure/chatgpt-v-2", @@ -846,30 +913,32 @@ def test_completion_azure3(): except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_azure3() -# new azure test for using litellm. vars, + +# new azure test for using litellm. 
vars, # use the following vars in this test and make an azure_api_call -# litellm.api_type = self.azure_api_type -# litellm.api_base = self.azure_api_base -# litellm.api_version = self.azure_api_version -# litellm.api_key = self.api_key +# litellm.api_type = self.azure_api_type +# litellm.api_base = self.azure_api_base +# litellm.api_version = self.azure_api_version +# litellm.api_key = self.api_key def test_completion_azure_with_litellm_key(): try: print("azure gpt-3.5 test\n\n") import openai - #### set litellm vars litellm.api_type = "azure" - litellm.api_base = os.environ['AZURE_API_BASE'] - litellm.api_version = os.environ['AZURE_API_VERSION'] - litellm.api_key = os.environ['AZURE_API_KEY'] + litellm.api_base = os.environ["AZURE_API_BASE"] + litellm.api_version = os.environ["AZURE_API_VERSION"] + litellm.api_key = os.environ["AZURE_API_KEY"] ######### UNSET ENV VARs for this ################ - os.environ['AZURE_API_BASE'] = "" - os.environ['AZURE_API_VERSION'] = "" - os.environ['AZURE_API_KEY'] = "" + os.environ["AZURE_API_BASE"] = "" + os.environ["AZURE_API_VERSION"] = "" + os.environ["AZURE_API_KEY"] = "" ######### UNSET OpenAI vars for this ############## openai.api_type = "" @@ -884,11 +953,10 @@ def test_completion_azure_with_litellm_key(): # Add any assertions here to check the response print(response) - ######### RESET ENV VARs for this ################ - os.environ['AZURE_API_BASE'] = litellm.api_base - os.environ['AZURE_API_VERSION'] = litellm.api_version - os.environ['AZURE_API_KEY'] = litellm.api_key + os.environ["AZURE_API_BASE"] = litellm.api_base + os.environ["AZURE_API_VERSION"] = litellm.api_version + os.environ["AZURE_API_KEY"] = litellm.api_key ######### UNSET litellm vars litellm.api_type = None @@ -898,6 +966,8 @@ def test_completion_azure_with_litellm_key(): except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_azure() @@ -913,6 +983,8 @@ def test_completion_azure_deployment_id(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_azure_deployment_id() # Only works for local endpoint @@ -930,14 +1002,15 @@ def test_completion_azure_deployment_id(): # test_completion_anthropic_openai_proxy() + def test_completion_replicate_vicuna(): print("TESTING REPLICATE") - litellm.set_verbose=True + litellm.set_verbose = True model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b" try: response = completion( - model=model_name, - messages=messages, + model=model_name, + messages=messages, temperature=0.5, top_k=20, repetition_penalty=1, @@ -953,6 +1026,8 @@ def test_completion_replicate_vicuna(): pytest.fail(f"Error occurred: {e}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_replicate_vicuna() # commenting out - flaky test # def test_completion_replicate_llama2_stream(): @@ -960,20 +1035,20 @@ def test_completion_replicate_vicuna(): # model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0" # try: # response = completion( -# model=model_name, +# model=model_name, # messages=[ # { # "role": "user", # "content": "what is yc write 1 paragraph", # } -# ], +# ], # stream=True, # max_tokens=20, # num_retries=3 # ) # print(f"response: {response}") # # Add any assertions here to check the response -# complete_response = "" +# complete_response = "" # for i, chunk in enumerate(response): # complete_response += chunk.choices[0].delta["content"] # # if i == 0: @@ -985,42 +1060,44 @@ 
def test_completion_replicate_vicuna(): # pytest.fail(f"Error occurred: {e}") # test_completion_replicate_llama2_stream() -def test_replicate_custom_prompt_dict(): + +def test_replicate_custom_prompt_dict(): litellm.set_verbose = True model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0" litellm.register_prompt_template( model="replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0", - initial_prompt_value="You are a good assistant", # [OPTIONAL] + initial_prompt_value="You are a good assistant", # [OPTIONAL] roles={ "system": { - "pre_message": "[INST] <>\n", # [OPTIONAL] - "post_message": "\n<>\n [/INST]\n" # [OPTIONAL] + "pre_message": "[INST] <>\n", # [OPTIONAL] + "post_message": "\n<>\n [/INST]\n", # [OPTIONAL] + }, + "user": { + "pre_message": "[INST] ", # [OPTIONAL] + "post_message": " [/INST]", # [OPTIONAL] }, - "user": { - "pre_message": "[INST] ", # [OPTIONAL] - "post_message": " [/INST]" # [OPTIONAL] - }, "assistant": { - "pre_message": "\n", # [OPTIONAL] - "post_message": "\n" # [OPTIONAL] - } + "pre_message": "\n", # [OPTIONAL] + "post_message": "\n", # [OPTIONAL] + }, }, - final_prompt_value="Now answer as best you can:" # [OPTIONAL] + final_prompt_value="Now answer as best you can:", # [OPTIONAL] ) response = completion( - model=model_name, - messages=[ - { - "role": "user", - "content": "what is yc write 1 paragraph", - } - ], - num_retries=3 + model=model_name, + messages=[ + { + "role": "user", + "content": "what is yc write 1 paragraph", + } + ], + num_retries=3, ) print(f"response: {response}") - litellm.custom_prompt_dict = {} # reset + litellm.custom_prompt_dict = {} # reset -# test_replicate_custom_prompt_dict() + +# test_replicate_custom_prompt_dict() # commenthing this out since we won't be always testing a custom replicate deployment # def test_completion_replicate_deployments(): @@ -1029,8 +1106,8 @@ def test_replicate_custom_prompt_dict(): # model_name = "replicate/deployments/ishaan-jaff/ishaan-mistral" # try: # response = completion( -# model=model_name, -# messages=messages, +# model=model_name, +# messages=messages, # temperature=0.5, # seed=-1, # ) @@ -1045,61 +1122,88 @@ def test_replicate_custom_prompt_dict(): # test_completion_replicate_deployments() -######## Test TogetherAI ######## +######## Test TogetherAI ######## def test_completion_together_ai(): model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct" try: - messages =[ + messages = [ {"role": "user", "content": "Who are you"}, {"role": "assistant", "content": "I am your helpful assistant."}, {"role": "user", "content": "Tell me a joke"}, ] - response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn) + response = completion( + model=model_name, + messages=messages, + max_tokens=256, + n=1, + logger_fn=logger_fn, + ) # Add any assertions here to check the response print(response) cost = completion_cost(completion_response=response) - assert cost > 0.0 - print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}") + assert cost > 0.0 + print( + "Cost for completion call together-computer/llama-2-70b: ", + f"${float(cost):.10f}", + ) except Exception as e: pytest.fail(f"Error occurred: {e}") + def test_completion_together_ai_mixtral(): model_name = "together_ai/DiscoResearch/DiscoLM-mixtral-8x7b-v2" try: - messages =[ + messages = [ {"role": "user", "content": "Who are you"}, {"role": "assistant", "content": "I am your 
helpful assistant."}, {"role": "user", "content": "Tell me a joke"}, ] - response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn) + response = completion( + model=model_name, + messages=messages, + max_tokens=256, + n=1, + logger_fn=logger_fn, + ) # Add any assertions here to check the response print(response) cost = completion_cost(completion_response=response) - assert cost > 0.0 - print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}") + assert cost > 0.0 + print( + "Cost for completion call together-computer/llama-2-70b: ", + f"${float(cost):.10f}", + ) except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_together_ai_mixtral() + def test_completion_together_ai_yi_chat(): model_name = "together_ai/zero-one-ai/Yi-34B-Chat" try: - messages =[ + messages = [ {"role": "user", "content": "What llm are you?"}, ] response = completion(model=model_name, messages=messages) # Add any assertions here to check the response print(response) cost = completion_cost(completion_response=response) - assert cost > 0.0 - print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}") + assert cost > 0.0 + print( + "Cost for completion call together-computer/llama-2-70b: ", + f"${float(cost):.10f}", + ) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_together_ai_yi_chat() + # test_completion_together_ai() def test_customprompt_together_ai(): try: @@ -1110,8 +1214,21 @@ def test_customprompt_together_ai(): print(litellm._async_success_callback) response = completion( model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", - messages=messages, - roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}} + messages=messages, + roles={ + "system": { + "pre_message": "<|im_start|>system\n", + "post_message": "<|im_end|>", + }, + "assistant": { + "pre_message": "<|im_start|>assistant\n", + "post_message": "<|im_end|>", + }, + "user": { + "pre_message": "<|im_start|>user\n", + "post_message": "<|im_end|>", + }, + }, ) print(response) except litellm.exceptions.Timeout as e: @@ -1121,14 +1238,16 @@ def test_customprompt_together_ai(): print(f"ERROR TYPE {type(e)}") pytest.fail(f"Error occurred: {e}") + # test_customprompt_together_ai() + def test_completion_sagemaker(): try: print("testing sagemaker") - litellm.set_verbose=True + litellm.set_verbose = True response = completion( - model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, temperature=0.2, max_tokens=80, @@ -1137,33 +1256,39 @@ def test_completion_sagemaker(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_sagemaker() + + +# test_completion_sagemaker() + def test_completion_chat_sagemaker(): try: messages = [{"role": "user", "content": "Hey, how's it going?"}] - litellm.set_verbose=True + litellm.set_verbose = True response = completion( - model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, max_tokens=100, temperature=0.7, stream=True, ) - # Add any assertions here to check the response - complete_response = "" + # Add any assertions here to check the response + 
complete_response = "" for chunk in response: - complete_response += chunk.choices[0].delta.content or "" + complete_response += chunk.choices[0].delta.content or "" print(f"complete_response: {complete_response}") assert len(complete_response) > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_chat_sagemaker() -def test_completion_chat_sagemaker_mistral(): - try: + +def test_completion_chat_sagemaker_mistral(): + try: messages = [{"role": "user", "content": "Hey, how's it going?"}] - + response = completion( model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct", messages=messages, @@ -1171,19 +1296,20 @@ def test_completion_chat_sagemaker_mistral(): ) # Add any assertions here to check the response print(response) - except Exception as e: + except Exception as e: pytest.fail(f"An error occurred: {str(e)}") + # test_completion_chat_sagemaker_mistral() def test_completion_bedrock_titan(): try: response = completion( - model="bedrock/amazon.titan-tg1-large", + model="bedrock/amazon.titan-tg1-large", messages=messages, temperature=0.2, max_tokens=200, top_p=0.8, - logger_fn=logger_fn + logger_fn=logger_fn, ) # Add any assertions here to check the response print(response) @@ -1191,17 +1317,20 @@ def test_completion_bedrock_titan(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_titan() + def test_completion_bedrock_claude(): print("calling claude") try: response = completion( - model="anthropic.claude-instant-v1", + model="anthropic.claude-instant-v1", messages=messages, max_tokens=10, temperature=0.1, - logger_fn=logger_fn + logger_fn=logger_fn, ) # Add any assertions here to check the response print(response) @@ -1209,18 +1338,21 @@ def test_completion_bedrock_claude(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_claude() + def test_completion_bedrock_cohere(): print("calling bedrock cohere") litellm.set_verbose = True try: response = completion( - model="bedrock/cohere.command-text-v14", + model="bedrock/cohere.command-text-v14", messages=[{"role": "user", "content": "hi"}], temperature=0.1, max_tokens=10, - stream=True + stream=True, ) # Add any assertions here to check the response print(response) @@ -1230,6 +1362,8 @@ def test_completion_bedrock_cohere(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_cohere() @@ -1245,10 +1379,9 @@ def test_completion_bedrock_claude_completion_auth(): os.environ["AWS_SECRET_ACCESS_KEY"] = "" os.environ["AWS_REGION_NAME"] = "" - try: response = completion( - model="bedrock/anthropic.claude-instant-v1", + model="bedrock/anthropic.claude-instant-v1", messages=messages, max_tokens=10, temperature=0.1, @@ -1267,6 +1400,8 @@ def test_completion_bedrock_claude_completion_auth(): pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_bedrock_claude_completion_auth() # def test_completion_bedrock_claude_external_client_auth(): @@ -1316,7 +1451,7 @@ def test_completion_bedrock_claude_completion_auth(): # litellm.set_verbose = False # try: # response = completion( -# model="bedrock/anthropic.claude-instant-v1", +# model="bedrock/anthropic.claude-instant-v1", # messages=messages, # stream=True # ) @@ -1334,13 +1469,13 @@ def test_completion_bedrock_claude_completion_auth(): # try: # litellm.set_verbose = False # response = completion( -# model="bedrock/ai21.j2-mid", +# model="bedrock/ai21.j2-mid", # messages=messages, # temperature=0.2, # top_p=0.2, # 
max_tokens=20 # ) -# # Add any assertions here to check the response +# # Add any assertions here to check the response # print(response) # except RateLimitError: # pass @@ -1352,7 +1487,7 @@ def test_completion_bedrock_claude_completion_auth(): # def test_completion_vllm(): # try: # response = completion( -# model="vllm/facebook/opt-125m", +# model="vllm/facebook/opt-125m", # messages=messages, # temperature=0.2, # max_tokens=80, @@ -1371,7 +1506,7 @@ def test_completion_bedrock_claude_completion_auth(): # try: # litellm.set_verbose = True # response = completion( -# model="facebook/opt-125m", +# model="facebook/opt-125m", # messages=messages, # temperature=0.2, # max_tokens=80, @@ -1391,7 +1526,7 @@ def test_completion_bedrock_claude_completion_auth(): # def test_completion_custom_api_base(): # try: # response = completion( -# model="custom/meta-llama/Llama-2-13b-hf", +# model="custom/meta-llama/Llama-2-13b-hf", # messages=messages, # temperature=0.2, # max_tokens=10, @@ -1405,6 +1540,7 @@ def test_completion_bedrock_claude_completion_auth(): # test_completion_custom_api_base() + def test_completion_with_fallbacks(): print(f"RUNNING TEST COMPLETION WITH FALLBACKS - test_completion_with_fallbacks") fallbacks = ["gpt-3.5-turbo", "gpt-3.5-turbo", "command-nightly"] @@ -1417,76 +1553,91 @@ def test_completion_with_fallbacks(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_with_fallbacks() def test_completion_anyscale_api(): try: # litellm.set_verbose=True - messages=[{ - "role": "system", - "content": "You're a good bot" - },{ - "role": "user", - "content": "Hey", - },{ - "role": "user", - "content": "Hey", - }] - response = completion( - model="anyscale/meta-llama/Llama-2-7b-chat-hf", - messages=messages,) - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - -# test_completion_anyscale_api() - -def test_azure_cloudflare_api(): - try: messages = [ - { - "role": "user", - "content": "How do I output all files in a directory using Python?", - }, - ] - response = completion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY")) - print(f"response: {response}") - except Exception as e: - traceback.print_exc() - pass - -# test_azure_cloudflare_api() - -def test_completion_anyscale_2(): - try: - # litellm.set_verbose=True - messages=[{ - "role": "system", - "content": "You're a good bot" - },{ - "role": "user", - "content": "Hey", - },{ - "role": "user", - "content": "Hey", - }] + {"role": "system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + { + "role": "user", + "content": "Hey", + }, + ] response = completion( - model="anyscale/meta-llama/Llama-2-7b-chat-hf", - messages=messages + model="anyscale/meta-llama/Llama-2-7b-chat-hf", + messages=messages, ) print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + +# test_completion_anyscale_api() + + +def test_azure_cloudflare_api(): + try: + messages = [ + { + "role": "user", + "content": "How do I output all files in a directory using Python?", + }, + ] + response = completion( + model="azure/gpt-turbo", + messages=messages, + base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), + api_key=os.getenv("AZURE_FRANCE_API_KEY"), + ) + print(f"response: {response}") + except Exception as e: + traceback.print_exc() + pass + + +# test_azure_cloudflare_api() + + +def test_completion_anyscale_2(): + try: + # litellm.set_verbose=True + messages = [ + {"role": 
"system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + { + "role": "user", + "content": "Hey", + }, + ] + response = completion( + model="anyscale/meta-llama/Llama-2-7b-chat-hf", messages=messages + ) + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_mistral_anyscale_stream(): - litellm.set_verbose=False + litellm.set_verbose = False response = completion( - model = 'anyscale/mistralai/Mistral-7B-Instruct-v0.1', - messages = [{ "content": "hello, good morning","role": "user"}], + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", + messages=[{"content": "hello, good morning", "role": "user"}], stream=True, ) for chunk in response: # print(chunk) print(chunk["choices"][0]["delta"].get("content", ""), end="") + + # test_mistral_anyscale_stream() # test_completion_anyscale_2() # def test_completion_with_fallbacks_multiple_keys(): @@ -1504,7 +1655,7 @@ def test_mistral_anyscale_stream(): # error_str = traceback.format_exc() # pytest.fail(f"Error occurred: {error_str}") -# test_completion_with_fallbacks_multiple_keys() +# test_completion_with_fallbacks_multiple_keys() # def test_petals(): # try: # response = completion(model="petals-team/StableBeluga2", messages=messages) @@ -1577,54 +1728,60 @@ def test_mistral_anyscale_stream(): #### Test A121 ################### def test_completion_ai21(): print("running ai21 j2light test") - litellm.set_verbose=True + litellm.set_verbose = True model_name = "j2-light" - try: - response = completion(model=model_name, messages=messages, max_tokens=100, temperature=0.8) - # Add any assertions here to check the response - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - -# test_completion_ai21() -# test_completion_ai21() -## test deep infra -def test_completion_deep_infra(): - litellm.set_verbose = False - model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf" try: response = completion( - model=model_name, - messages=messages, - temperature=0, - max_tokens=10 + model=model_name, messages=messages, max_tokens=100, temperature=0.8 ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + +# test_completion_ai21() +# test_completion_ai21() +## test deep infra +def test_completion_deep_infra(): + litellm.set_verbose = False + model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf" + try: + response = completion( + model=model_name, messages=messages, temperature=0, max_tokens=10 + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_completion_deep_infra() + def test_completion_deep_infra_mistral(): print("deep infra test with temp=0") model_name = "deepinfra/mistralai/Mistral-7B-Instruct-v0.1" try: response = completion( - model=model_name, + model=model_name, messages=messages, - temperature=0.01, # mistrail fails with temperature=0 - max_tokens=10 + temperature=0.01, # mistrail fails with temperature=0 + max_tokens=10, ) # Add any assertions here to check the response print(response) - except litellm.exceptions.Timeout as e: + except litellm.exceptions.Timeout as e: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_deep_infra_mistral() + # Gemini tests -def test_completion_gemini(): +def test_completion_gemini(): litellm.set_verbose = True model_name = "gemini/gemini-pro" messages = [{"role": "user", "content": "Hey, how's it going?"}] @@ -1634,7 
+1791,11 @@ def test_completion_gemini(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -test_completion_gemini() + + +test_completion_gemini() + + # Palm tests def test_completion_palm(): litellm.set_verbose = True @@ -1646,25 +1807,30 @@ def test_completion_palm(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_palm() + # test palm with streaming def test_completion_palm_stream(): # litellm.set_verbose = True model_name = "palm/chat-bison" try: response = completion( - model=model_name, + model=model_name, messages=messages, stop=["stop"], stream=True, - max_tokens=20 + max_tokens=20, ) # Add any assertions here to check the response for chunk in response: print(chunk) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_palm_stream() # test_completion_deep_infra() @@ -1691,14 +1857,16 @@ def test_completion_palm_stream(): # pytest.fail(f"Error occurred: {e}") # test_maritalk() + def test_completion_together_ai_stream(): user_message = "Write 1pg about YC & litellm" - messages = [{ "content": user_message,"role": "user"}] + messages = [{"content": user_message, "role": "user"}] try: response = completion( - model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", - messages=messages, stream=True, - max_tokens=5 + model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", + messages=messages, + stream=True, + max_tokens=5, ) print(response) for chunk in response: @@ -1706,6 +1874,8 @@ def test_completion_together_ai_stream(): # print(string_response) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_together_ai_stream() @@ -1716,14 +1886,17 @@ def test_completion_together_ai_stream(): # test_completion_together_ai_stream() + def test_moderation(): import openai - openai.api_type = "azure" + + openai.api_type = "azure" openai.api_version = "GM" - response = litellm.moderation(input="i'm ishaan cto of litellm") + response = litellm.moderation(input="i'm ishaan cto of litellm") print(response) output = response.results[0] print(output) return output -# test_moderation() \ No newline at end of file + +# test_moderation() diff --git a/litellm/tests/test_completion_with_retries.py b/litellm/tests/test_completion_with_retries.py index 416051e38..422794531 100644 --- a/litellm/tests/test_completion_with_retries.py +++ b/litellm/tests/test_completion_with_retries.py @@ -28,6 +28,7 @@ def logger_fn(user_model_dict): # print(f"user_model_dict: {user_model_dict}") pass + # normal call def test_completion_custom_provider_model_name(): try: @@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # completion with num retries + impact on exception mapping -def test_completion_with_num_retries(): - try: - response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2) +def test_completion_with_num_retries(): + try: + response = completion( + model="j2-ultra", + messages=[{"messages": "vibe", "bad": "message"}], + num_retries=2, + ) pytest.fail(f"Unmapped exception occurred") - except Exception as e: + except Exception as e: pass + # test_completion_with_num_retries() def test_completion_with_0_num_retries(): try: - litellm.set_verbose=False + litellm.set_verbose = False print("making request") # Use the completion function response = completion( model="gpt-3.5-turbo", messages=[{"gm": "vibe", "role": "user"}], - max_retries=4 + max_retries=4, ) 
print(response) @@ -69,5 +76,6 @@ def test_completion_with_0_num_retries(): print("exception", e) pass + # Call the test function -test_completion_with_0_num_retries() \ No newline at end of file +test_completion_with_0_num_retries() diff --git a/litellm/tests/test_config.py b/litellm/tests/test_config.py index ceecaf181..69e37cf87 100644 --- a/litellm/tests/test_config.py +++ b/litellm/tests/test_config.py @@ -15,77 +15,104 @@ from litellm import completion_with_config config = { "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], "model": { - "claude-instant-1": { - "needs_moderation": True - }, + "claude-instant-1": {"needs_moderation": True}, "gpt-3.5-turbo": { "error_handling": { - "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} + "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} } - } - } + }, + }, } + def test_config_context_window_exceeded(): try: sample_text = "how does a court case get to the Supreme Court?" * 1000 messages = [{"content": sample_text, "role": "user"}] - response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) + response = completion_with_config( + model="gpt-3.5-turbo", messages=messages, config=config + ) print(response) except Exception as e: print(f"Exception: {e}") pytest.fail(f"An exception occurred: {e}") -# test_config_context_window_exceeded() + +# test_config_context_window_exceeded() + def test_config_context_moderation(): try: - messages=[{"role": "user", "content": "I want to kill them."}] - response = completion_with_config(model="claude-instant-1", messages=messages, config=config) + messages = [{"role": "user", "content": "I want to kill them."}] + response = completion_with_config( + model="claude-instant-1", messages=messages, config=config + ) print(response) except Exception as e: print(f"Exception: {e}") pytest.fail(f"An exception occurred: {e}") -# test_config_context_moderation() + +# test_config_context_moderation() + def test_config_context_default_fallback(): try: - messages=[{"role": "user", "content": "Hey, how's it going?"}] - response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key") + messages = [{"role": "user", "content": "Hey, how's it going?"}] + response = completion_with_config( + model="claude-instant-1", + messages=messages, + config=config, + api_key="bad-key", + ) print(response) except Exception as e: print(f"Exception: {e}") pytest.fail(f"An exception occurred: {e}") -# test_config_context_default_fallback() + +# test_config_context_default_fallback() config = { "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], - "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"], - "adapt_to_prompt_size": True, # type: ignore + "available_models": [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "j2-ultra", + "command-nightly", + "togethercomputer/llama-2-70b-chat", + "chat-bison", + "chat-bison@001", + "claude-2", + ], + "adapt_to_prompt_size": True, # type: ignore "model": { - "claude-instant-1": { - "needs_moderation": True - }, + "claude-instant-1": {"needs_moderation": True}, "gpt-3.5-turbo": { "error_handling": { - "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} + 
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} } - } - } + }, + }, } + def test_config_context_adapt_to_prompt(): try: sample_text = "how does a court case get to the Supreme Court?" * 1000 messages = [{"content": sample_text, "role": "user"}] - response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) + response = completion_with_config( + model="gpt-3.5-turbo", messages=messages, config=config + ) print(response) except Exception as e: print(f"Exception: {e}") pytest.fail(f"An exception occurred: {e}") -test_config_context_adapt_to_prompt() \ No newline at end of file + +test_config_context_adapt_to_prompt() diff --git a/litellm/tests/test_configs/custom_auth.py b/litellm/tests/test_configs/custom_auth.py index b37ff8370..f3825038e 100644 --- a/litellm/tests/test_configs/custom_auth.py +++ b/litellm/tests/test_configs/custom_auth.py @@ -1,14 +1,16 @@ from litellm.proxy._types import UserAPIKeyAuth from fastapi import Request from dotenv import load_dotenv -import os +import os load_dotenv() -async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: - try: + + +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: print(f"api_key: {api_key}") if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": return UserAPIKeyAuth(api_key=api_key) raise Exception - except: - raise Exception \ No newline at end of file + except: + raise Exception diff --git a/litellm/tests/test_configs/custom_callbacks.py b/litellm/tests/test_configs/custom_callbacks.py index 7aa1577f6..226908413 100644 --- a/litellm/tests/test_configs/custom_callbacks.py +++ b/litellm/tests/test_configs/custom_callbacks.py @@ -2,30 +2,35 @@ from litellm.integrations.custom_logger import CustomLogger import inspect import litellm + class testCustomCallbackProxy(CustomLogger): def __init__(self): - self.success: bool = False # type: ignore - self.failure: bool = False # type: ignore - self.async_success: bool = False # type: ignore + self.success: bool = False # type: ignore + self.failure: bool = False # type: ignore + self.async_success: bool = False # type: ignore self.async_success_embedding: bool = False # type: ignore - self.async_failure: bool = False # type: ignore + self.async_failure: bool = False # type: ignore self.async_failure_embedding: bool = False # type: ignore - self.async_completion_kwargs = None # type: ignore - self.async_embedding_kwargs = None # type: ignore - self.async_embedding_response = None # type: ignore + self.async_completion_kwargs = None # type: ignore + self.async_embedding_kwargs = None # type: ignore + self.async_embedding_response = None # type: ignore - self.async_completion_kwargs_fail = None # type: ignore - self.async_embedding_kwargs_fail = None # type: ignore + self.async_completion_kwargs_fail = None # type: ignore + self.async_embedding_kwargs_fail = None # type: ignore - self.streaming_response_obj = None # type: ignore + self.streaming_response_obj = None # type: ignore blue_color_code = "\033[94m" reset_color_code = "\033[0m" print(f"{blue_color_code}Initialized LiteLLM custom logger") try: print(f"Logger Initialized with following methods:") - methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] - + methods = [ + method + for method in dir(self) + if inspect.ismethod(getattr(self, method)) + ] + # Pretty print the methods for method in methods: print(f" - {method}") @@ -33,29 +38,32 @@ class testCustomCallbackProxy(CustomLogger): except: pass - 
def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") - - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): print(f"Post-API Call") - + def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") - - def log_success_event(self, kwargs, response_obj, start_time, end_time): + + def log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Success") self.success = True - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") self.failure = True - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Async success") self.async_success = True print("Value of async success: ", self.async_success) print("\n kwargs: ", kwargs) - if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada": + if ( + kwargs.get("model") == "azure-embedding-model" + or kwargs.get("model") == "ada" + ): print("Got an embedding model", kwargs.get("model")) print("Setting embedding success to True") self.async_success_embedding = True @@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger): if kwargs.get("stream") == True: self.streaming_response_obj = response_obj - self.async_completion_kwargs = kwargs model = kwargs.get("model", None) @@ -74,17 +81,18 @@ class testCustomCallbackProxy(CustomLogger): # Access litellm_params passed to litellm.completion(), example access `metadata` litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here + metadata = litellm_params.get( + "metadata", {} + ) # headers passed to LiteLLM proxy, can be found here # Calculate cost using litellm.completion_cost() cost = litellm.completion_cost(completion_response=response_obj) response = response_obj - # tokens used in response + # tokens used in response usage = response_obj["usage"] print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger)) - print( f""" Model: {model}, @@ -98,8 +106,7 @@ class testCustomCallbackProxy(CustomLogger): ) return - - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Async Failure") self.async_failure = True print("Value of async failure: ", self.async_failure) @@ -107,7 +114,8 @@ class testCustomCallbackProxy(CustomLogger): if kwargs.get("model") == "text-embedding-ada-002": self.async_failure_embedding = True self.async_embedding_kwargs_fail = kwargs - + self.async_completion_kwargs_fail = kwargs -my_custom_logger = testCustomCallbackProxy() \ No newline at end of file + +my_custom_logger = testCustomCallbackProxy() diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index d193c53b6..6bd2656e2 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -3,7 +3,8 @@ import sys, os, time, inspect, asyncio, traceback from datetime import datetime import pytest -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) from typing import 
Optional, Literal, List, Union from litellm import completion, embedding, Cache import litellm @@ -16,216 +17,284 @@ from litellm.integrations.custom_logger import CustomLogger ## 4: On LiteLLM Call failure ## 5. Caching -# Test models -## 1. OpenAI -## 2. Azure OpenAI +# Test models +## 1. OpenAI +## 2. Azure OpenAI ## 3. Non-OpenAI/Azure - e.g. Bedrock # Test interfaces ## 1. litellm.completion() + litellm.embeddings() ## refer to test_custom_callback_input_router.py for the router + proxy tests -class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + +class CompletionCustomHandler( + CustomLogger +): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class """ - The set of expected inputs to a custom handler for a + The set of expected inputs to a custom handler for a """ + # Class variables or attributes def __init__(self): self.errors = [] - self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] - - def log_pre_api_call(self, model, messages, kwargs): - try: + self.states: Optional[ + List[ + Literal[ + "sync_pre_api_call", + "async_pre_api_call", + "post_api_call", + "sync_stream", + "async_stream", + "sync_success", + "async_success", + "sync_failure", + "async_failure", + ] + ] + ] = [] + + def log_pre_api_call(self, model, messages, kwargs): + try: self.states.append("sync_pre_api_call") ## MODEL assert isinstance(model, str) ## MESSAGES assert isinstance(messages, list) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - except Exception as e: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): try: self.states.append("post_api_call") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert end_time == None - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert response_obj == None - ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or 
inspect.isasyncgen(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: - print(f"Assertion Error: {traceback.format_exc()}") - self.errors.append(traceback.format_exc()) - - async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): - try: - self.states.append("async_stream") - ## START TIME - assert isinstance(start_time, datetime) - ## END TIME - assert isinstance(end_time, datetime) - ## RESPONSE OBJECT - assert isinstance(response_obj, litellm.ModelResponse) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: - print(f"Assertion Error: {traceback.format_exc()}") - self.errors.append(traceback.format_exc()) - - def log_success_event(self, kwargs, response_obj, start_time, end_time): - try: - self.states.append("sync_success") - ## START TIME - assert isinstance(start_time, datetime) - ## END TIME - assert isinstance(end_time, datetime) - ## RESPONSE OBJECT - assert isinstance(response_obj, litellm.ModelResponse) - ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.iscoroutine(kwargs["original_response"]) + or inspect.isasyncgen(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) except: print(f"Assertion 
Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): try: - self.states.append("sync_failure") - ## START TIME + self.states.append("async_stream") + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT - assert response_obj == None + ## RESPONSE OBJECT + assert isinstance(response_obj, litellm.ModelResponse) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or inspect.iscoroutine(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("sync_success") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert isinstance(response_obj, litellm.ModelResponse) + ## KWARGS + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + assert isinstance(kwargs["additional_args"], (dict, 
type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("sync_failure") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert response_obj == None + ## KWARGS + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or kwargs["original_response"] == None + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + async def async_log_pre_api_call(self, model, messages, kwargs): - try: + try: self.states.append("async_pre_api_call") ## MODEL assert isinstance(model, str) ## MESSAGES assert isinstance(messages, list) and isinstance(messages[0], dict) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - except Exception as e: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - try: + try: self.states.append("async_success") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT - assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)) + ## RESPONSE OBJECT + assert isinstance( + response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse) + ) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, 
dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or inspect.iscoroutine(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) - except: + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.states.append("async_failure") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert response_obj == None ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, str, dict)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, str, dict)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or kwargs["original_response"] == None + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) @@ -233,33 +302,30 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse # COMPLETION ## Test OpenAI + sync def test_chat_openai_stream(): - try: + try: 
customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = litellm.completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync openai" - }]) + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}], + ) ## test streaming - response = litellm.completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - stream=True) - for chunk in response: + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + stream=True, + ) + for chunk in response: continue ## test failure callback - try: - response = litellm.completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - api_key="my-bad-key", - stream=True) - for chunk in response: + try: + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + api_key="my-bad-key", + stream=True, + ) + for chunk in response: continue except: pass @@ -267,41 +333,40 @@ def test_chat_openai_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # test_chat_openai_stream() + ## Test OpenAI + Async @pytest.mark.asyncio async def test_async_chat_openai_stream(): - try: + try: customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = await litellm.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }]) + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + ) ## test streaming - response = await litellm.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - stream=True) - async for chunk in response: + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + stream=True, + ) + async for chunk in response: continue ## test failure callback - try: - response = await litellm.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - api_key="my-bad-key", - stream=True) - async for chunk in response: + try: + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + api_key="my-bad-key", + stream=True, + ) + async for chunk in response: continue except: pass @@ -309,40 +374,39 @@ async def test_async_chat_openai_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_chat_openai_stream()) + ## Test Azure + sync def test_chat_azure_stream(): - try: + try: customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = litellm.completion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync azure" - }]) + response = litellm.completion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}], + ) # test streaming - response = litellm.completion(model="azure/chatgpt-v-2", - 
messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync azure" - }], - stream=True) - for chunk in response: + response = litellm.completion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}], + stream=True, + ) + for chunk in response: continue # test failure callback - try: - response = litellm.completion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync azure" - }], - api_key="my-bad-key", - stream=True) - for chunk in response: + try: + response = litellm.completion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}], + api_key="my-bad-key", + stream=True, + ) + for chunk in response: continue except: pass @@ -350,41 +414,40 @@ def test_chat_azure_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # test_chat_azure_stream() + ## Test Azure + Async @pytest.mark.asyncio async def test_async_chat_azure_stream(): - try: + try: customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = await litellm.acompletion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async azure" - }]) + response = await litellm.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}], + ) ## test streaming - response = await litellm.acompletion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async azure" - }], - stream=True) - async for chunk in response: + response = await litellm.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}], + stream=True, + ) + async for chunk in response: continue ## test failure callback - try: - response = await litellm.acompletion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async azure" - }], - api_key="my-bad-key", - stream=True) - async for chunk in response: + try: + response = await litellm.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}], + api_key="my-bad-key", + stream=True, + ) + async for chunk in response: continue except: pass @@ -392,40 +455,39 @@ async def test_async_chat_azure_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_chat_azure_stream()) + ## Test Bedrock + sync def test_chat_bedrock_stream(): - try: + try: customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = litellm.completion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync bedrock" - }]) + response = litellm.completion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}], + ) # test streaming - response = litellm.completion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync bedrock" - }], - stream=True) - for chunk in response: + response = litellm.completion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}], + stream=True, + ) + for chunk in response: 
continue # test failure callback - try: - response = litellm.completion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm sync bedrock" - }], - aws_region_name="my-bad-region", - stream=True) - for chunk in response: + try: + response = litellm.completion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}], + aws_region_name="my-bad-region", + stream=True, + ) + for chunk in response: continue except: pass @@ -433,43 +495,42 @@ def test_chat_bedrock_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # test_chat_bedrock_stream() + ## Test Bedrock + Async @pytest.mark.asyncio async def test_async_chat_bedrock_stream(): - try: + try: customHandler = CompletionCustomHandler() litellm.callbacks = [customHandler] - response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async bedrock" - }]) + response = await litellm.acompletion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}], + ) # test streaming - response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async bedrock" - }], - stream=True) + response = await litellm.acompletion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}], + stream=True, + ) print(f"response: {response}") - async for chunk in response: + async for chunk in response: print(f"chunk: {chunk}") continue ## test failure callback - try: - response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm async bedrock" - }], - aws_region_name="my-bad-key", - stream=True) - async for chunk in response: + try: + response = await litellm.acompletion( + model="bedrock/anthropic.claude-v1", + messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}], + aws_region_name="my-bad-key", + stream=True, + ) + async for chunk in response: continue except: pass @@ -477,155 +538,190 @@ async def test_async_chat_bedrock_stream(): print(f"customHandler.errors: {customHandler.errors}") assert len(customHandler.errors) == 0 litellm.callbacks = [] - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_chat_bedrock_stream()) + # EMBEDDING ## Test OpenAI + Async @pytest.mark.asyncio async def test_async_embedding_openai(): - try: + try: customHandler_success = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler() litellm.callbacks = [customHandler_success] - response = await litellm.aembedding(model="azure/azure-embedding-model", - input=["good morning from litellm"]) + response = await litellm.aembedding( + model="azure/azure-embedding-model", input=["good morning from litellm"] + ) await asyncio.sleep(1) print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.states: {customHandler_success.states}") assert len(customHandler_success.errors) == 0 - assert len(customHandler_success.states) == 3 # pre, post, success + assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback litellm.callbacks = [customHandler_failure] - try: - 
response = await litellm.aembedding(model="text-embedding-ada-002", - input=["good morning from litellm"], - api_key="my-bad-key") + try: + response = await litellm.aembedding( + model="text-embedding-ada-002", + input=["good morning from litellm"], + api_key="my-bad-key", + ) except: pass await asyncio.sleep(1) print(f"customHandler_failure.errors: {customHandler_failure.errors}") print(f"customHandler_failure.states: {customHandler_failure.states}") assert len(customHandler_failure.errors) == 0 - assert len(customHandler_failure.states) == 3 # pre, post, failure - except Exception as e: + assert len(customHandler_failure.states) == 3 # pre, post, failure + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_embedding_openai()) + ## Test Azure + Async @pytest.mark.asyncio async def test_async_embedding_azure(): - try: + try: customHandler_success = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler() litellm.callbacks = [customHandler_success] - response = await litellm.aembedding(model="azure/azure-embedding-model", - input=["good morning from litellm"]) + response = await litellm.aembedding( + model="azure/azure-embedding-model", input=["good morning from litellm"] + ) await asyncio.sleep(1) print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.states: {customHandler_success.states}") assert len(customHandler_success.errors) == 0 - assert len(customHandler_success.states) == 3 # pre, post, success + assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback litellm.callbacks = [customHandler_failure] - try: - response = await litellm.aembedding(model="azure/azure-embedding-model", - input=["good morning from litellm"], - api_key="my-bad-key") + try: + response = await litellm.aembedding( + model="azure/azure-embedding-model", + input=["good morning from litellm"], + api_key="my-bad-key", + ) except: pass await asyncio.sleep(1) print(f"customHandler_failure.errors: {customHandler_failure.errors}") print(f"customHandler_failure.states: {customHandler_failure.states}") assert len(customHandler_failure.errors) == 0 - assert len(customHandler_failure.states) == 3 # pre, post, success - except Exception as e: + assert len(customHandler_failure.states) == 3 # pre, post, success + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_embedding_azure()) + ## Test Bedrock + Async @pytest.mark.asyncio async def test_async_embedding_bedrock(): - try: + try: customHandler_success = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler() litellm.callbacks = [customHandler_success] litellm.set_verbose = True - response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3", - input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2") + response = await litellm.aembedding( + model="bedrock/cohere.embed-multilingual-v3", + input=["good morning from litellm"], + aws_region_name="os.environ/AWS_REGION_NAME_2", + ) await asyncio.sleep(1) print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.states: {customHandler_success.states}") assert len(customHandler_success.errors) == 0 - assert len(customHandler_success.states) == 3 # pre, post, success + assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback litellm.callbacks = [customHandler_failure] - try: - response = await 
litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3", - input=["good morning from litellm"], - aws_region_name="my-bad-region") + try: + response = await litellm.aembedding( + model="bedrock/cohere.embed-multilingual-v3", + input=["good morning from litellm"], + aws_region_name="my-bad-region", + ) except: pass await asyncio.sleep(1) print(f"customHandler_failure.errors: {customHandler_failure.errors}") print(f"customHandler_failure.states: {customHandler_failure.states}") assert len(customHandler_failure.errors) == 0 - assert len(customHandler_failure.states) == 3 # pre, post, success - except Exception as e: + assert len(customHandler_failure.states) == 3 # pre, post, success + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # asyncio.run(test_async_embedding_bedrock()) -# CACHING + +# CACHING ## Test Azure - completion, embedding @pytest.mark.asyncio async def test_async_completion_azure_caching(): customHandler_caching = CompletionCustomHandler() - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) litellm.callbacks = [customHandler_caching] unique_time = time.time() - response1 = await litellm.acompletion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": f"Hi 👋 - i'm async azure {unique_time}" - }], - caching=True) + response1 = await litellm.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + caching=True, + ) await asyncio.sleep(1) print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") - response2 = await litellm.acompletion(model="azure/chatgpt-v-2", - messages=[{ - "role": "user", - "content": f"Hi 👋 - i'm async azure {unique_time}" - }], - caching=True) - await asyncio.sleep(1) # success callbacks are done in parallel - print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}") + response2 = await litellm.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + caching=True, + ) + await asyncio.sleep(1) # success callbacks are done in parallel + print( + f"customHandler_caching.states post-cache hit: {customHandler_caching.states}" + ) assert len(customHandler_caching.errors) == 0 - assert len(customHandler_caching.states) == 4 # pre, post, success, success + assert len(customHandler_caching.states) == 4 # pre, post, success, success + @pytest.mark.asyncio async def test_async_embedding_azure_caching(): print("Testing custom callback input - Azure Caching") customHandler_caching = CompletionCustomHandler() - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) litellm.callbacks = [customHandler_caching] unique_time = time.time() - response1 = await litellm.aembedding(model="azure/azure-embedding-model", - input=[f"good morning from litellm1 {unique_time}"], - caching=True) - await asyncio.sleep(1) # set cache is async for aembedding() - response2 = await litellm.aembedding(model="azure/azure-embedding-model", - input=[f"good morning from litellm1 
{unique_time}"], - caching=True) - await asyncio.sleep(1) # success callbacks are done in parallel + response1 = await litellm.aembedding( + model="azure/azure-embedding-model", + input=[f"good morning from litellm1 {unique_time}"], + caching=True, + ) + await asyncio.sleep(1) # set cache is async for aembedding() + response2 = await litellm.aembedding( + model="azure/azure-embedding-model", + input=[f"good morning from litellm1 {unique_time}"], + caching=True, + ) + await asyncio.sleep(1) # success callbacks are done in parallel print(customHandler_caching.states) assert len(customHandler_caching.errors) == 0 - assert len(customHandler_caching.states) == 4 # pre, post, success, success + assert len(customHandler_caching.states) == 4 # pre, post, success, success + # asyncio.run( # test_async_embedding_azure_caching() -# ) \ No newline at end of file +# ) diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py index 43d532521..ac8b2fa10 100644 --- a/litellm/tests/test_custom_callback_router.py +++ b/litellm/tests/test_custom_callback_router.py @@ -3,7 +3,8 @@ import sys, os, time, inspect, asyncio, traceback from datetime import datetime import pytest -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) from typing import Optional, Literal, List from litellm import Router, Cache import litellm @@ -14,206 +15,274 @@ from litellm.integrations.custom_logger import CustomLogger ## 2: Post-API-Call ## 3: On LiteLLM Call success ## 4: On LiteLLM Call failure -## fallbacks -## retries +## fallbacks +## retries -# Test cases -## 1. Simple Azure OpenAI acompletion + streaming call -## 2. Simple Azure OpenAI aembedding call +# Test cases +## 1. Simple Azure OpenAI acompletion + streaming call +## 2. Simple Azure OpenAI aembedding call ## 3. Azure OpenAI acompletion + streaming call with retries ## 4. Azure OpenAI aembedding call with retries ## 5. Azure OpenAI acompletion + streaming call with fallbacks ## 6. Azure OpenAI aembedding call with fallbacks # Test interfaces -## 1. router.completion() + router.embeddings() -## 2. proxy.completions + proxy.embeddings +## 1. router.completion() + router.embeddings() +## 2. 
proxy.completions + proxy.embeddings -class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + +class CompletionCustomHandler( + CustomLogger +): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class """ - The set of expected inputs to a custom handler for a + The set of expected inputs to a custom handler for a """ + # Class variables or attributes def __init__(self): self.errors = [] - self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] + self.states: Optional[ + List[ + Literal[ + "sync_pre_api_call", + "async_pre_api_call", + "post_api_call", + "sync_stream", + "async_stream", + "sync_success", + "async_success", + "sync_failure", + "async_failure", + ] + ] + ] = [] - def log_pre_api_call(self, model, messages, kwargs): - try: - print(f'received kwargs in pre-input: {kwargs}') + def log_pre_api_call(self, model, messages, kwargs): + try: + print(f"received kwargs in pre-input: {kwargs}") self.states.append("sync_pre_api_call") ## MODEL assert isinstance(model, str) ## MESSAGES assert isinstance(messages, list) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) ### ROUTER-SPECIFIC KWARGS assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) - assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) - assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) + assert isinstance( + kwargs["litellm_params"]["proxy_server_request"], (str, type(None)) + ) + assert isinstance( + kwargs["litellm_params"]["preset_cache_key"], (str, type(None)) + ) assert isinstance(kwargs["litellm_params"]["stream_response"], dict) - except Exception as e: + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): try: self.states.append("post_api_call") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert end_time == None - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert response_obj == None - ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert 
isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) + ## KWARGS + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.iscoroutine(kwargs["original_response"]) + or inspect.isasyncgen(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) ### ROUTER-SPECIFIC KWARGS assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) - assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) - assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) + assert isinstance( + kwargs["litellm_params"]["proxy_server_request"], (str, type(None)) + ) + assert isinstance( + kwargs["litellm_params"]["preset_cache_key"], (str, type(None)) + ) assert isinstance(kwargs["litellm_params"]["stream_response"], dict) - except: + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - + async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): try: self.states.append("async_stream") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert isinstance(response_obj, litellm.ModelResponse) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert 
isinstance(kwargs['log_event_type'], str) - except: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or inspect.iscoroutine(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - def log_success_event(self, kwargs, response_obj, start_time, end_time): + def log_success_event(self, kwargs, response_obj, start_time, end_time): try: self.states.append("sync_success") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert isinstance(response_obj, litellm.ModelResponse) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + def log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.states.append("sync_failure") - ## START TIME + ## START TIME assert 
isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert response_obj == None ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) and isinstance( + kwargs["messages"][0], dict + ) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert ( + isinstance(kwargs["input"], list) + and isinstance(kwargs["input"][0], dict) + ) or isinstance(kwargs["input"], (dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or kwargs["original_response"] == None + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) - + async def async_log_pre_api_call(self, model, messages, kwargs): - try: + try: """ - No-op. - Not implemented yet. + No-op. + Not implemented yet. 
""" pass - except Exception as e: + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - try: + try: self.states.append("async_success") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + ## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT - assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)) + ## RESPONSE OBJECT + assert isinstance( + response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse) + ) ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, dict, str)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, dict, str)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or inspect.iscoroutine(kwargs["original_response"]) + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) ### ROUTER-SPECIFIC KWARGS assert isinstance(kwargs["litellm_params"]["metadata"], dict) @@ -221,10 +290,14 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) - assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) - assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) + assert isinstance( + kwargs["litellm_params"]["proxy_server_request"], (str, type(None)) + ) + assert isinstance( + kwargs["litellm_params"]["preset_cache_key"], (str, type(None)) + ) assert isinstance(kwargs["litellm_params"]["stream_response"], dict) - except: + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) @@ -232,257 +305,281 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse try: print(f"received original response: {kwargs['original_response']}") self.states.append("async_failure") - ## START TIME + ## START TIME assert isinstance(start_time, datetime) - ## END TIME + 
## END TIME assert isinstance(end_time, datetime) - ## RESPONSE OBJECT + ## RESPONSE OBJECT assert response_obj == None ## KWARGS - assert isinstance(kwargs['model'], str) - assert isinstance(kwargs['messages'], list) - assert isinstance(kwargs['optional_params'], dict) - assert isinstance(kwargs['litellm_params'], dict) - assert isinstance(kwargs['start_time'], (datetime, type(None))) - assert isinstance(kwargs['stream'], bool) - assert isinstance(kwargs['user'], (str, type(None))) - assert isinstance(kwargs['input'], (list, str, dict)) - assert isinstance(kwargs['api_key'], (str, type(None))) - assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None - assert isinstance(kwargs['additional_args'], (dict, type(None))) - assert isinstance(kwargs['log_event_type'], str) - except: + assert isinstance(kwargs["model"], str) + assert isinstance(kwargs["messages"], list) + assert isinstance(kwargs["optional_params"], dict) + assert isinstance(kwargs["litellm_params"], dict) + assert isinstance(kwargs["start_time"], (datetime, type(None))) + assert isinstance(kwargs["stream"], bool) + assert isinstance(kwargs["user"], (str, type(None))) + assert isinstance(kwargs["input"], (list, str, dict)) + assert isinstance(kwargs["api_key"], (str, type(None))) + assert ( + isinstance( + kwargs["original_response"], (str, litellm.CustomStreamWrapper) + ) + or inspect.isasyncgen(kwargs["original_response"]) + or inspect.iscoroutine(kwargs["original_response"]) + or kwargs["original_response"] == None + ) + assert isinstance(kwargs["additional_args"], (dict, type(None))) + assert isinstance(kwargs["log_event_type"], str) + except: print(f"Assertion Error: {traceback.format_exc()}") self.errors.append(traceback.format_exc()) -# Simple Azure OpenAI call + +# Simple Azure OpenAI call ## COMPLETION @pytest.mark.asyncio async def test_async_chat_azure(): - try: + try: customHandler_completion_azure_router = CompletionCustomHandler() customHandler_streaming_azure_router = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler() litellm.callbacks = [customHandler_completion_azure_router] model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - ] - router = Router(model_list=model_list) # type: ignore - response = await router.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }]) + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + router = Router(model_list=model_list) # type: ignore + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + ) await asyncio.sleep(2) assert len(customHandler_completion_azure_router.errors) == 0 - assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success - # streaming + assert ( + 
len(customHandler_completion_azure_router.states) == 3 + ) # pre, post, success + # streaming litellm.callbacks = [customHandler_streaming_azure_router] - router2 = Router(model_list=model_list) # type: ignore - response = await router2.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - stream=True) - async for chunk in response: + router2 = Router(model_list=model_list) # type: ignore + response = await router2.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + stream=True, + ) + async for chunk in response: print(f"async azure router chunk: {chunk}") continue await asyncio.sleep(1) print(f"customHandler.states: {customHandler_streaming_azure_router.states}") assert len(customHandler_streaming_azure_router.errors) == 0 - assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success - # failure + assert ( + len(customHandler_streaming_azure_router.states) >= 4 + ) # pre, post, stream (multiple times), success + # failure model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "my-bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - ] + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": "my-bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore - try: - response = await router3.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }]) + router3 = Router(model_list=model_list) # type: ignore + try: + response = await router3.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + ) print(f"response in router3 acompletion: {response}") except: pass await asyncio.sleep(1) print(f"customHandler.states: {customHandler_failure.states}") assert len(customHandler_failure.errors) == 0 - assert len(customHandler_failure.states) == 3 # pre, post, failure + assert len(customHandler_failure.states) == 3 # pre, post, failure assert "async_failure" in customHandler_failure.states - except Exception as e: + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") pytest.fail(f"An exception occurred - {str(e)}") + + # asyncio.run(test_async_chat_azure()) ## EMBEDDING @pytest.mark.asyncio async def test_async_embedding_azure(): - try: + try: customHandler = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler() litellm.callbacks = [customHandler] model_list = [ - { - "model_name": "azure-embedding-model", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/azure-embedding-model", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - ] - router = Router(model_list=model_list) # type: ignore - response = await router.aembedding(model="azure-embedding-model", - input=["hello from litellm!"]) + { + "model_name": 
"azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + router = Router(model_list=model_list) # type: ignore + response = await router.aembedding( + model="azure-embedding-model", input=["hello from litellm!"] + ) await asyncio.sleep(2) assert len(customHandler.errors) == 0 - assert len(customHandler.states) == 3 # pre, post, success - # failure + assert len(customHandler.states) == 3 # pre, post, success + # failure model_list = [ - { - "model_name": "azure-embedding-model", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/azure-embedding-model", - "api_key": "my-bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - ] + { + "model_name": "azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": "my-bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore - try: - response = await router3.aembedding(model="azure-embedding-model", - input=["hello from litellm!"]) + router3 = Router(model_list=model_list) # type: ignore + try: + response = await router3.aembedding( + model="azure-embedding-model", input=["hello from litellm!"] + ) print(f"response in router3 aembedding: {response}") except: pass await asyncio.sleep(1) print(f"customHandler.states: {customHandler_failure.states}") assert len(customHandler_failure.errors) == 0 - assert len(customHandler_failure.states) == 3 # pre, post, failure + assert len(customHandler_failure.states) == 3 # pre, post, failure assert "async_failure" in customHandler_failure.states - except Exception as e: + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") pytest.fail(f"An exception occurred - {str(e)}") + + # asyncio.run(test_async_embedding_azure()) # Azure OpenAI call w/ Fallbacks ## COMPLETION @pytest.mark.asyncio -async def test_async_chat_azure_with_fallbacks(): - try: +async def test_async_chat_azure_with_fallbacks(): + try: customHandler_fallbacks = CompletionCustomHandler() litellm.callbacks = [customHandler_fallbacks] - # with fallbacks + # with fallbacks model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "my-bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo-16k", - "litellm_params": { - "model": "gpt-3.5-turbo-16k", - }, - "tpm": 240000, - "rpm": 1800 - } - ] - router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore - response = await router.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }]) + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": 
"azure/chatgpt-v-2", + "api_key": "my-bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo-16k", + "litellm_params": { + "model": "gpt-3.5-turbo-16k", + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + ) await asyncio.sleep(2) print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}") assert len(customHandler_fallbacks.errors) == 0 - assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success + assert ( + len(customHandler_fallbacks.states) == 6 + ) # pre, post, failure, pre, post, success litellm.callbacks = [] - except Exception as e: + except Exception as e: print(f"Assertion Error: {traceback.format_exc()}") pytest.fail(f"An exception occurred - {str(e)}") + + # asyncio.run(test_async_chat_azure_with_fallbacks()) -# CACHING + +# CACHING ## Test Azure - completion, embedding @pytest.mark.asyncio async def test_async_completion_azure_caching(): customHandler_caching = CompletionCustomHandler() - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) litellm.callbacks = [customHandler_caching] unique_time = time.time() model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo-16k", - "litellm_params": { - "model": "gpt-3.5-turbo-16k", - }, - "tpm": 240000, - "rpm": 1800 - } - ] - router = Router(model_list=model_list) # type: ignore - response1 = await router.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": f"Hi 👋 - i'm async azure {unique_time}" - }], - caching=True) + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo-16k", + "litellm_params": { + "model": "gpt-3.5-turbo-16k", + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + router = Router(model_list=model_list) # type: ignore + response1 = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + caching=True, + ) await asyncio.sleep(1) print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") - response2 = await router.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": f"Hi 👋 - i'm async azure {unique_time}" - }], - caching=True) - await asyncio.sleep(1) # success callbacks are done in parallel - print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}") + response2 = await router.acompletion( 
+ model="gpt-3.5-turbo", + messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + caching=True, + ) + await asyncio.sleep(1) # success callbacks are done in parallel + print( + f"customHandler_caching.states post-cache hit: {customHandler_caching.states}" + ) assert len(customHandler_caching.errors) == 0 - assert len(customHandler_caching.states) == 4 # pre, post, success, success + assert len(customHandler_caching.states) == 4 # pre, post, success, success diff --git a/litellm/tests/test_dynamodb_logs.py b/litellm/tests/test_dynamodb_logs.py index 6e40c9512..a6bedd0cf 100644 --- a/litellm/tests/test_dynamodb_logs.py +++ b/litellm/tests/test_dynamodb_logs.py @@ -1,12 +1,14 @@ import sys import os import io, asyncio + # import logging # logging.basicConfig(level=logging.DEBUG) -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion import litellm + litellm.num_retries = 3 import time, random @@ -29,11 +31,14 @@ def pre_request(): import re -def verify_log_file(log_file_path): - with open(log_file_path, 'r') as log_file: + +def verify_log_file(log_file_path): + with open(log_file_path, "r") as log_file: log_content = log_file.read() - print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content) + print( + f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content + ) # Define the pattern to search for in the log file pattern = r"Response from DynamoDB:{.*?}" @@ -50,17 +55,21 @@ def verify_log_file(log_file_path): print(f"Total occurrences of specified response: {len(matches)}") # Count the occurrences of successful responses (status code 200 or 201) - success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match) + success_count = sum( + 1 + for match in matches + if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match + ) # Print the count of successful responses print(f"Count of successful responses from DynamoDB: {success_count}") - assert success_count == 3 # Expect 3 success logs from dynamoDB + assert success_count == 3 # Expect 3 success logs from dynamoDB -def test_dynamo_logging(): +def test_dynamo_logging(): # all dynamodb requests need to be in one test function # since we are modifying stdout, and pytests runs tests in parallel - try: + try: # pre # redirect stdout to log_file @@ -69,44 +78,44 @@ def test_dynamo_logging(): litellm.set_verbose = True original_stdout, log_file, file_name = pre_request() - print("Testing async dynamoDB logging") + async def _test(): return await litellm.acompletion( model="gpt-3.5-turbo", - messages=[{"role": "user", "content":"This is a test"}], + messages=[{"role": "user", "content": "This is a test"}], max_tokens=100, temperature=0.7, - user = "ishaan-2" + user="ishaan-2", ) + response = asyncio.run(_test()) print(f"response: {response}") - - # streaming + async + # streaming + async async def _test2(): - response = await litellm.acompletion( + response = await litellm.acompletion( model="gpt-3.5-turbo", - messages=[{"role": "user", "content":"This is a test"}], + messages=[{"role": "user", "content": "This is a test"}], max_tokens=10, temperature=0.7, - user = "ishaan-2", - stream=True + user="ishaan-2", + stream=True, ) async for chunk in response: pass + asyncio.run(_test2()) # aembedding() async def _test3(): return await litellm.aembedding( - model="text-embedding-ada-002", - input = ["hi"], - user = "ishaan-2" + model="text-embedding-ada-002", 
input=["hi"], user="ishaan-2" ) + response = asyncio.run(_test3()) time.sleep(1) - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred - {e}") finally: # post, close log file and verify @@ -117,4 +126,5 @@ def test_dynamo_logging(): verify_log_file(file_name) print("Passed! Testing async dynamoDB logging") + # test_dynamo_logging_async() diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 9a2a5951a..2ef068e28 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -14,39 +14,49 @@ from litellm import embedding, completion litellm.set_verbose = False + def test_openai_embedding(): try: - litellm.set_verbose=True + litellm.set_verbose = True response = embedding( - model="text-embedding-ada-002", - input=["good morning from litellm", "this is another item"], - metadata = {"anything": "good day"} + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"], + metadata={"anything": "good day"}, ) litellm_response = dict(response) litellm_response_keys = set(litellm_response.keys()) - litellm_response_keys.discard('_response_ms') + litellm_response_keys.discard("_response_ms") print(litellm_response_keys) print("LiteLLM Response\n") # print(litellm_response) - - # same request with OpenAI 1.0+ + + # same request with OpenAI 1.0+ import openai - client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) + + client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) response = client.embeddings.create( - model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"] + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"], ) response = dict(response) openai_response_keys = set(response.keys()) print(openai_response_keys) - assert litellm_response_keys == openai_response_keys # ENSURE the Keys in litellm response is exactly what the openai package returns - assert len(litellm_response["data"]) == 2 # expect two embedding responses from litellm_response since input had two + assert ( + litellm_response_keys == openai_response_keys + ) # ENSURE the Keys in litellm response is exactly what the openai package returns + assert ( + len(litellm_response["data"]) == 2 + ) # expect two embedding responses from litellm_response since input had two print(openai_response_keys) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_openai_embedding() + def test_openai_azure_embedding_simple(): try: response = embedding( @@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple(): ) print(response) response_keys = set(dict(response).keys()) - response_keys.discard('_response_ms') - assert set(["usage", "model", "object", "data"]) == set(response_keys) #assert litellm response has expected keys from OpenAI embedding response + response_keys.discard("_response_ms") + assert set(["usage", "model", "object", "data"]) == set( + response_keys + ) # assert litellm response has expected keys from OpenAI embedding response except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_openai_azure_embedding_simple() @@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts(): response = embedding( model="azure/azure-embedding-model", input=["good morning from litellm"], - timeout=0.00001 + timeout=0.00001, ) print(response) except openai.APITimeoutError: print("Good job got timeout error!") pass except Exception as e: - pytest.fail(f"Expected timeout error, did not get the correct error. 
Instead got {e}") + pytest.fail( + f"Expected timeout error, did not get the correct error. Instead got {e}" + ) + # test_openai_azure_embedding_timeouts() + def test_openai_embedding_timeouts(): try: response = embedding( model="text-embedding-ada-002", input=["good morning from litellm"], - timeout=0.00001 + timeout=0.00001, ) print(response) except openai.APITimeoutError: print("Good job got OpenAI timeout error!") pass except Exception as e: - pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}") + pytest.fail( + f"Expected timeout error, did not get the correct error. Instead got {e}" + ) + + # test_openai_embedding_timeouts() + def test_openai_azure_embedding(): try: - api_key = os.environ['AZURE_API_KEY'] - api_base = os.environ['AZURE_API_BASE'] - api_version = os.environ['AZURE_API_VERSION'] + api_key = os.environ["AZURE_API_KEY"] + api_base = os.environ["AZURE_API_BASE"] + api_version = os.environ["AZURE_API_VERSION"] - os.environ['AZURE_API_VERSION'] = "" - os.environ['AZURE_API_BASE'] = "" - os.environ['AZURE_API_KEY'] = "" + os.environ["AZURE_API_VERSION"] = "" + os.environ["AZURE_API_BASE"] = "" + os.environ["AZURE_API_KEY"] = "" response = embedding( model="azure/azure-embedding-model", @@ -114,137 +136,179 @@ def test_openai_azure_embedding(): ) print(response) - - os.environ['AZURE_API_VERSION'] = api_version - os.environ['AZURE_API_BASE'] = api_base - os.environ['AZURE_API_KEY'] = api_key + os.environ["AZURE_API_VERSION"] = api_version + os.environ["AZURE_API_BASE"] = api_base + os.environ["AZURE_API_KEY"] = api_key except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_openai_azure_embedding() # test_openai_embedding() + def test_cohere_embedding(): try: # litellm.set_verbose=True response = embedding( - model="embed-english-v2.0", input=["good morning from litellm", "this is another item"] + model="embed-english-v2.0", + input=["good morning from litellm", "this is another item"], ) print(f"response:", response) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_cohere_embedding() + def test_cohere_embedding3(): try: - litellm.set_verbose=True + litellm.set_verbose = True response = embedding( - model="embed-english-v3.0", - input=["good morning from litellm", "this is another item"], + model="embed-english-v3.0", + input=["good morning from litellm", "this is another item"], ) print(f"response:", response) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_cohere_embedding3() + def test_bedrock_embedding_titan(): try: - litellm.set_verbose=True + litellm.set_verbose = True response = embedding( - model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data", - "lets test a second string for good measure"] + model="amazon.titan-embed-text-v1", + input=[ + "good morning from litellm, attempting to embed data", + "lets test a second string for good measure", + ], ) print(f"response:", response) - assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list" - print(f"type of first embedding:", type(response['data'][0]['embedding'][0])) - assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" + assert isinstance( + response["data"][0]["embedding"], list + ), "Expected response to be a list" + print(f"type of first embedding:", type(response["data"][0]["embedding"][0])) + assert all( + isinstance(x, float) for x in response["data"][0]["embedding"] + ), "Expected 
response to be a list of floats" except Exception as e: pytest.fail(f"Error occurred: {e}") + + test_bedrock_embedding_titan() + def test_bedrock_embedding_cohere(): try: - litellm.set_verbose=False + litellm.set_verbose = False response = embedding( - model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"], - aws_region_name="os.environ/AWS_REGION_NAME_2" + model="cohere.embed-multilingual-v3", + input=[ + "good morning from litellm, attempting to embed data", + "lets test a second string for good measure", + ], + aws_region_name="os.environ/AWS_REGION_NAME_2", ) - assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list" - print(f"type of first embedding:", type(response['data'][0]['embedding'][0])) - assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" + assert isinstance( + response["data"][0]["embedding"], list + ), "Expected response to be a list" + print(f"type of first embedding:", type(response["data"][0]["embedding"][0])) + assert all( + isinstance(x, float) for x in response["data"][0]["embedding"] + ), "Expected response to be a list of floats" # print(f"response:", response) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_bedrock_embedding_cohere() + # comment out hf tests - since hf endpoints are unstable def test_hf_embedding(): try: # huggingface/microsoft/codebert-base # huggingface/facebook/bart-large response = embedding( - model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"] + model="huggingface/sentence-transformers/all-MiniLM-L6-v2", + input=["good morning from litellm", "this is another item"], ) print(f"response:", response) except Exception as e: # Note: Huggingface inference API is unstable and fails with "model loading errors all the time" pass + + # test_hf_embedding() + # test async embeddings def test_aembedding(): try: import asyncio + async def embedding_call(): try: response = await litellm.aembedding( - model="text-embedding-ada-002", - input=["good morning from litellm", "this is another item"] + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"], ) print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + asyncio.run(embedding_call()) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_aembedding() def test_aembedding_azure(): try: import asyncio + async def embedding_call(): try: response = await litellm.aembedding( - model="azure/azure-embedding-model", - input=["good morning from litellm", "this is another item"] + model="azure/azure-embedding-model", + input=["good morning from litellm", "this is another item"], ) print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + asyncio.run(embedding_call()) except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_aembedding_azure() -def test_sagemaker_embeddings(): - try: - response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"]) + +def test_sagemaker_embeddings(): + try: + response = litellm.embedding( + model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", + input=["good morning from litellm", "this is another item"], + ) print(f"response: {response}") - except Exception as e: + except Exception as e: pytest.fail(f"Error occurred: {e}") + + # 
test_sagemaker_embeddings() # def local_proxy_embeddings(): # litellm.set_verbose=True # response = embedding( -# model="openai/custom_embedding", +# model="openai/custom_embedding", # input=["good morning from litellm"], # api_base="http://0.0.0.0:8000/" # ) diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index f6a0ba25b..1cb599206 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -11,17 +11,18 @@ import litellm from litellm import ( embedding, completion, -# AuthenticationError, + # AuthenticationError, ContextWindowExceededError, -# RateLimitError, -# ServiceUnavailableError, -# OpenAIError, + # RateLimitError, + # ServiceUnavailableError, + # OpenAIError, ) from concurrent.futures import ThreadPoolExecutor import pytest + litellm.vertex_project = "pathrise-convert-1606954137718" litellm.vertex_location = "us-central1" -litellm.num_retries=0 +litellm.num_retries = 0 # litellm.failure_callback = ["sentry"] #### What this tests #### @@ -36,7 +37,8 @@ litellm.num_retries=0 models = ["command-nightly"] -# Test 1: Context Window Errors + +# Test 1: Context Window Errors @pytest.mark.parametrize("model", models) def test_context_window(model): print("Testing context window error") @@ -52,17 +54,27 @@ def test_context_window(model): print(f"Worked!") except RateLimitError: print("RateLimited!") - except Exception as e: + except Exception as e: print(f"{e}") pytest.fail(f"An error occcurred - {e}") - + + @pytest.mark.parametrize("model", models) def test_context_window_with_fallbacks(model): - ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"} + ctx_window_fallback_dict = { + "command-nightly": "claude-2", + "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", + "azure/chatgpt-v-2": "gpt-3.5-turbo-16k", + } sample_text = "how does a court case get to the Supreme Court?" 
* 1000 messages = [{"content": sample_text, "role": "user"}] - completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict) + completion( + model=model, + messages=messages, + context_window_fallback_dict=ctx_window_fallback_dict, + ) + # for model in litellm.models_by_provider["bedrock"]: # test_context_window(model=model) @@ -98,7 +110,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th os.environ["AI21_API_KEY"] = "bad-key" elif "togethercomputer" in model: temporary_key = os.environ["TOGETHERAI_API_KEY"] - os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a" + os.environ[ + "TOGETHERAI_API_KEY" + ] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a" elif model in litellm.openrouter_models: temporary_key = os.environ["OPENROUTER_API_KEY"] os.environ["OPENROUTER_API_KEY"] = "bad-key" @@ -115,9 +129,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th temporary_key = os.environ["REPLICATE_API_KEY"] os.environ["REPLICATE_API_KEY"] = "bad-key" print(f"model: {model}") - response = completion( - model=model, messages=messages - ) + response = completion(model=model, messages=messages) print(f"response: {response}") except AuthenticationError as e: print(f"AuthenticationError Caught Exception - {str(e)}") @@ -148,23 +160,25 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th os.environ["REPLICATE_API_KEY"] = temporary_key elif "j2" in model: os.environ["AI21_API_KEY"] = temporary_key - elif ("togethercomputer" in model): + elif "togethercomputer" in model: os.environ["TOGETHERAI_API_KEY"] = temporary_key elif model in litellm.aleph_alpha_models: os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key elif model in litellm.nlp_cloud_models: os.environ["NLP_CLOUD_API_KEY"] = temporary_key - elif "bedrock" in model: + elif "bedrock" in model: os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key os.environ["AWS_REGION_NAME"] = temporary_aws_region_name os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key return + # for model in litellm.models_by_provider["bedrock"]: # invalid_auth(model=model) # invalid_auth(model="command-nightly") -# Test 3: Invalid Request Error + +# Test 3: Invalid Request Error @pytest.mark.parametrize("model", models) def test_invalid_request_error(model): messages = [{"content": "hey, how's it going?", "role": "user"}] @@ -173,23 +187,18 @@ def test_invalid_request_error(model): completion(model=model, messages=messages, max_tokens="hello world") - def test_completion_azure_exception(): try: import openai + print("azure gpt-3.5 test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True ## Test azure call old_azure_key = os.environ["AZURE_API_KEY"] os.environ["AZURE_API_KEY"] = "good morning" response = completion( model="azure/chatgpt-v-2", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) os.environ["AZURE_API_KEY"] = old_azure_key print(f"response: {response}") @@ -199,25 +208,24 @@ def test_completion_azure_exception(): print("good job got the correct error for azure when key not set") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_azure_exception() + async def asynctest_completion_azure_exception(): try: import openai import litellm + print("azure gpt-3.5 test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True ## Test azure call old_azure_key = 
os.environ["AZURE_API_KEY"] os.environ["AZURE_API_KEY"] = "good morning" response = await litellm.acompletion( model="azure/chatgpt-v-2", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) print(f"response: {response}") print(response) @@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception(): print("Got wrong exception") print("exception", e) pytest.fail(f"Error occurred: {e}") + + # import asyncio # asyncio.run( # asynctest_completion_azure_exception() @@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model(): try: import openai import litellm, asyncio + print("azure exception bad model\n\n") - litellm.set_verbose=True + litellm.set_verbose = True + ## Test azure call async def test(): response = await litellm.acompletion( model="openai/gpt-6", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) + asyncio.run(test()) except openai.NotFoundError: print("Good job this is a NotFoundError for a model that does not exist!") @@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model(): assert isinstance(e, openai.BadRequestError) pytest.fail(f"Error occurred: {e}") -# asynctest_completion_openai_exception_bad_model() +# asynctest_completion_openai_exception_bad_model() def asynctest_completion_azure_exception_bad_model(): try: import openai import litellm, asyncio + print("azure exception bad model\n\n") - litellm.set_verbose=True + litellm.set_verbose = True + ## Test azure call async def test(): response = await litellm.acompletion( model="azure/gpt-12", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) + asyncio.run(test()) except openai.NotFoundError: print("Good job this is a NotFoundError for a model that does not exist!") @@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model(): print("Raised wrong type of exception", type(e)) pytest.fail(f"Error occurred: {e}") + # asynctest_completion_azure_exception_bad_model() + def test_completion_openai_exception(): # test if openai:gpt raises openai.AuthenticationError try: import openai + print("openai gpt-3.5 test\n\n") - litellm.set_verbose=True + litellm.set_verbose = True ## Test azure call old_azure_key = os.environ["OPENAI_API_KEY"] os.environ["OPENAI_API_KEY"] = "good morning" response = completion( model="gpt-4", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) print(f"response: {response}") print(response) @@ -317,25 +321,24 @@ def test_completion_openai_exception(): print("OpenAI: good job got the correct error for openai when key not set") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai_exception() + def test_completion_mistral_exception(): # test if mistral/mistral-tiny raises openai.AuthenticationError try: import openai + print("Testing mistral ai exception mapping") - litellm.set_verbose=True + litellm.set_verbose = True ## Test azure call old_azure_key = os.environ["MISTRAL_API_KEY"] os.environ["MISTRAL_API_KEY"] = "good morning" response = completion( model="mistral/mistral-tiny", - messages=[ - { - "role": "user", - "content": "hello" - } - ], + messages=[{"role": "user", "content": "hello"}], ) print(f"response: {response}") print(response) @@ -344,11 +347,11 @@ def test_completion_mistral_exception(): print("good job got the correct error for openai when key 
not set") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_mistral_exception() - - # # test_invalid_request_error(model="command-nightly") # # Test 3: Rate Limit Errors # def test_model_call(model): @@ -387,4 +390,4 @@ def test_completion_mistral_exception(): # counts[result] += 1 # accuracy_score = counts[True]/(counts[True] + counts[False]) -# print(f"accuracy_score: {accuracy_score}") \ No newline at end of file +# print(f"accuracy_score: {accuracy_score}") diff --git a/litellm/tests/test_function_calling.py b/litellm/tests/test_function_calling.py index a7f0225d4..2fcbdc946 100644 --- a/litellm/tests/test_function_calling.py +++ b/litellm/tests/test_function_calling.py @@ -13,6 +13,7 @@ import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError import pytest + litellm.num_retries = 0 litellm.cache = None # litellm.set_verbose=True @@ -20,23 +21,32 @@ import json # litellm.success_callback = ["langfuse"] + def get_current_weather(location, unit="fahrenheit"): """Get the current weather in a given location""" if "tokyo" in location.lower(): return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"}) elif "san francisco" in location.lower(): - return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}) + return json.dumps( + {"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"} + ) elif "paris" in location.lower(): return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"}) else: return json.dumps({"location": location, "temperature": "unknown"}) + # Example dummy function hard coded to return the same weather # In production, this could be your backend API or an external API def test_parallel_function_call(): try: # Step 1: send the conversation and available functions to the model - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + messages = [ + { + "role": "user", + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + } + ] tools = [ { "type": "function", @@ -50,7 +60,10 @@ def test_parallel_function_call(): "type": "string", "description": "The city and state, e.g. 
San Francisco, CA", }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, "required": ["location"], }, @@ -69,7 +82,9 @@ def test_parallel_function_call(): print("length of tool calls", len(tool_calls)) print("Expecting there to be 3 tool calls") - assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and parise + assert ( + len(tool_calls) > 1 + ) # this has to call the function for SF, Tokyo and parise # Step 2: check if the model wanted to call a function if tool_calls: @@ -78,7 +93,9 @@ def test_parallel_function_call(): available_functions = { "get_current_weather": get_current_weather, } # only one function in this example, but you can have multiple - messages.append(response_message) # extend conversation with assistant's reply + messages.append( + response_message + ) # extend conversation with assistant's reply print("Response message\n", response_message) # Step 4: send the info for each function call and function response to the model for tool_call in tool_calls: @@ -99,25 +116,26 @@ def test_parallel_function_call(): ) # extend conversation with function response print(f"messages: {messages}") second_response = litellm.completion( - model="gpt-3.5-turbo-1106", - messages=messages, - temperature=0.2, - seed=22 + model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22 ) # get a new response from the model where it can see the function response print("second response\n", second_response) return second_response except Exception as e: pytest.fail(f"Error occurred: {e}") + test_parallel_function_call() - - def test_parallel_function_call_stream(): try: # Step 1: send the conversation and available functions to the model - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + messages = [ + { + "role": "user", + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + } + ] tools = [ { "type": "function", @@ -131,7 +149,10 @@ def test_parallel_function_call_stream(): "type": "string", "description": "The city and state, e.g. 
San Francisco, CA", }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, "required": ["location"], }, @@ -144,7 +165,7 @@ def test_parallel_function_call_stream(): tools=tools, stream=True, tool_choice="auto", # auto is default, but we'll be explicit - complete_response = True + complete_response=True, ) print("Response\n", response) # for chunk in response: @@ -154,7 +175,9 @@ def test_parallel_function_call_stream(): print("length of tool calls", len(tool_calls)) print("Expecting there to be 3 tool calls") - assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and parise + assert ( + len(tool_calls) > 1 + ) # this has to call the function for SF, Tokyo and parise # Step 2: check if the model wanted to call a function if tool_calls: @@ -163,7 +186,9 @@ def test_parallel_function_call_stream(): available_functions = { "get_current_weather": get_current_weather, } # only one function in this example, but you can have multiple - messages.append(response_message) # extend conversation with assistant's reply + messages.append( + response_message + ) # extend conversation with assistant's reply print("Response message\n", response_message) # Step 4: send the info for each function call and function response to the model for tool_call in tool_calls: @@ -184,14 +209,12 @@ def test_parallel_function_call_stream(): ) # extend conversation with function response print(f"messages: {messages}") second_response = litellm.completion( - model="gpt-3.5-turbo-1106", - messages=messages, - temperature=0.2, - seed=22 + model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22 ) # get a new response from the model where it can see the function response print("second response\n", second_response) return second_response except Exception as e: pytest.fail(f"Error occurred: {e}") -test_parallel_function_call_stream() \ No newline at end of file + +test_parallel_function_call_stream() diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index afcb3d7af..cf28f2d0d 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -11,9 +11,11 @@ sys.path.insert( import pytest import litellm + def test_get_llm_provider(): _, response, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1") assert response == "bedrock" -test_get_llm_provider() \ No newline at end of file + +test_get_llm_provider() diff --git a/litellm/tests/test_get_model_cost_map.py b/litellm/tests/test_get_model_cost_map.py index b7763da12..e6181bff5 100644 --- a/litellm/tests/test_get_model_cost_map.py +++ b/litellm/tests/test_get_model_cost_map.py @@ -9,41 +9,60 @@ import litellm from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models import pytest + def test_get_gpt3_tokens(): max_tokens = get_max_tokens("gpt-3.5-turbo") print(max_tokens) - assert max_tokens==4097 + assert max_tokens == 4097 # print(results) + + test_get_gpt3_tokens() + def test_get_palm_tokens(): # # 🦄🦄🦄🦄🦄🦄🦄🦄 max_tokens = get_max_tokens("palm/chat-bison") assert max_tokens == 4096 print(max_tokens) + + test_get_palm_tokens() + def test_zephyr_hf_tokens(): max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta") print(max_tokens) assert max_tokens == 32768 + test_zephyr_hf_tokens() + def test_cost_ft_gpt_35(): try: # this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id # it needs to lookup 
ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost from litellm import ModelResponse, Choices, Message from litellm.utils import Usage + resp = ModelResponse( - id='chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac', - choices=[Choices(finish_reason=None, index=0, - message=Message(content=' Sure! Here is a short poem about the sky:\n\nA canvas of blue, a', role='assistant'))], - created=1700775391, - model='ft:gpt-3.5-turbo:my-org:custom_suffix:id', - object='chat.completion', system_fingerprint=None, - usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38) + id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", + choices=[ + Choices( + finish_reason=None, + index=0, + message=Message( + content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", + role="assistant", + ), + ) + ], + created=1700775391, + model="ft:gpt-3.5-turbo:my-org:custom_suffix:id", + object="chat.completion", + system_fingerprint=None, + usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38), ) cost = litellm.completion_cost(completion_response=resp) @@ -51,35 +70,58 @@ def test_cost_ft_gpt_35(): input_cost = model_cost["ft:gpt-3.5-turbo"]["input_cost_per_token"] output_cost = model_cost["ft:gpt-3.5-turbo"]["output_cost_per_token"] print(input_cost, output_cost) - expected_cost = (input_cost*resp.usage.prompt_tokens) + (output_cost*resp.usage.completion_tokens) + expected_cost = (input_cost * resp.usage.prompt_tokens) + ( + output_cost * resp.usage.completion_tokens + ) print("\n Excpected cost", expected_cost) assert cost == expected_cost except Exception as e: - pytest.fail(f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}") + pytest.fail( + f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}" + ) + + test_cost_ft_gpt_35() + def test_cost_azure_gpt_35(): try: # this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo # for this test we check if passing `model` to completion_cost overrides the completion cost from litellm import ModelResponse, Choices, Message from litellm.utils import Usage + resp = ModelResponse( - id='chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac', - choices=[Choices(finish_reason=None, index=0, - message=Message(content=' Sure! Here is a short poem about the sky:\n\nA canvas of blue, a', role='assistant'))], - model='azure/gpt-35-turbo', # azure always has model written like this - usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38) + id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", + choices=[ + Choices( + finish_reason=None, + index=0, + message=Message( + content=" Sure! 
Here is a short poem about the sky:\n\nA canvas of blue, a", + role="assistant", + ), + ) + ], + model="azure/gpt-35-turbo", # azure always has model written like this + usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38), ) - cost = litellm.completion_cost(completion_response=resp, model="azure/gpt-3.5-turbo") + cost = litellm.completion_cost( + completion_response=resp, model="azure/gpt-3.5-turbo" + ) print("\n Calculated Cost for azure/gpt-3.5-turbo", cost) input_cost = model_cost["azure/gpt-3.5-turbo"]["input_cost_per_token"] output_cost = model_cost["azure/gpt-3.5-turbo"]["output_cost_per_token"] - expected_cost = (input_cost*resp.usage.prompt_tokens) + (output_cost*resp.usage.completion_tokens) + expected_cost = (input_cost * resp.usage.prompt_tokens) + ( + output_cost * resp.usage.completion_tokens + ) print("\n Excpected cost", expected_cost) assert cost == expected_cost except Exception as e: - pytest.fail(f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}") -test_cost_azure_gpt_35() + pytest.fail( + f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}" + ) + +test_cost_azure_gpt_35() diff --git a/litellm/tests/test_get_model_file.py b/litellm/tests/test_get_model_file.py index 820465273..949ff43b8 100644 --- a/litellm/tests/test_get_model_file.py +++ b/litellm/tests/test_get_model_file.py @@ -9,4 +9,4 @@ import pytest try: print(litellm.get_model_cost_map(url="fake-url")) except Exception as e: - pytest.fail(f"An exception occurred: {e}") \ No newline at end of file + pytest.fail(f"An exception occurred: {e}") diff --git a/litellm/tests/test_hf_prompt_templates.py b/litellm/tests/test_hf_prompt_templates.py index c67779f87..ea1e6a7d8 100644 --- a/litellm/tests/test_hf_prompt_templates.py +++ b/litellm/tests/test_hf_prompt_templates.py @@ -11,19 +11,38 @@ sys.path.insert( import pytest from litellm.llms.prompt_templates.factory import prompt_factory -def test_prompt_formatting(): + +def test_prompt_formatting(): try: - prompt = prompt_factory(model="mistralai/Mistral-7B-Instruct-v0.1", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}]) - assert prompt == "[INST] Be a good bot [/INST] [INST] Hello world [/INST]" - except Exception as e: + prompt = prompt_factory( + model="mistralai/Mistral-7B-Instruct-v0.1", + messages=[ + {"role": "system", "content": "Be a good bot"}, + {"role": "user", "content": "Hello world"}, + ], + ) + assert ( + prompt == "[INST] Be a good bot [/INST] [INST] Hello world [/INST]" + ) + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") -def test_prompt_formatting_custom_model(): + +def test_prompt_formatting_custom_model(): try: - prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface") + prompt = prompt_factory( + model="ehartford/dolphin-2.5-mixtral-8x7b", + messages=[ + {"role": "system", "content": "Be a good bot"}, + {"role": "user", "content": "Hello world"}, + ], + custom_llm_provider="huggingface", + ) print(f"prompt: {prompt}") - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + + # test_prompt_formatting_custom_model() # def logger_fn(user_model_dict): # return @@ -31,7 +50,7 @@ def test_prompt_formatting_custom_model(): # messages=[{"role": "user", "content": "Write me a function to 
print hello world"}] -# # test if the first-party prompt templates work +# # test if the first-party prompt templates work # def test_huggingface_supported_models(): # model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0" # response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn) @@ -40,7 +59,7 @@ def test_prompt_formatting_custom_model(): # test_huggingface_supported_models() -# # test if a custom prompt template works +# # test if a custom prompt template works # litellm.register_prompt_template( # model="togethercomputer/LLaMA-2-7B-32K", # roles={"system":"", "assistant":"Assistant:", "user":"User:"}, @@ -53,4 +72,4 @@ def test_prompt_formatting_custom_model(): # print(response['choices'][0]['message']['content']) # return response -# test_huggingface_custom_model() \ No newline at end of file +# test_huggingface_custom_model() diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 06441e3f4..0438659e2 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -1,49 +1,75 @@ -# What this tests? -## This tests the litellm support for the openai /generations endpoint +# What this tests? +## This tests the litellm support for the openai /generations endpoint import sys, os import traceback from dotenv import load_dotenv import logging + logging.basicConfig(level=logging.DEBUG) load_dotenv() import os import asyncio + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest -import litellm +import litellm -def test_image_generation_openai(): + +def test_image_generation_openai(): litellm.set_verbose = True - response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") + response = litellm.image_generation( + prompt="A cute baby sea otter", model="dall-e-3" + ) print(f"response: {response}") assert len(response.data) > 0 + # test_image_generation_openai() -def test_image_generation_azure(): - response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview") + +def test_image_generation_azure(): + response = litellm.image_generation( + prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview" + ) print(f"response: {response}") assert len(response.data) > 0 + + # test_image_generation_azure() -def test_image_generation_azure_dall_e_3(): + +def test_image_generation_azure_dall_e_3(): litellm.set_verbose = True - response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY")) + response = litellm.image_generation( + prompt="A cute baby sea otter", + model="azure/dall-e-3-test", + api_version="2023-12-01-preview", + api_base=os.getenv("AZURE_SWEDEN_API_BASE"), + api_key=os.getenv("AZURE_SWEDEN_API_KEY"), + ) print(f"response: {response}") assert len(response.data) > 0 + + # test_image_generation_azure_dall_e_3() @pytest.mark.asyncio -async def test_async_image_generation_openai(): - response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") +async def test_async_image_generation_openai(): + response = litellm.image_generation( + prompt="A cute baby sea otter", model="dall-e-3" + ) print(f"response: {response}") assert len(response.data) > 0 + # 
asyncio.run(test_async_image_generation_openai()) + @pytest.mark.asyncio -async def test_async_image_generation_azure(): - response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test") - print(f"response: {response}") \ No newline at end of file +async def test_async_image_generation_azure(): + response = await litellm.aimage_generation( + prompt="A cute baby sea otter", model="azure/dall-e-3-test" + ) + print(f"response: {response}") diff --git a/litellm/tests/test_langchain_ChatLiteLLM.py b/litellm/tests/test_langchain_ChatLiteLLM.py index 46fb33150..27bc209f1 100644 --- a/litellm/tests/test_langchain_ChatLiteLLM.py +++ b/litellm/tests/test_langchain_ChatLiteLLM.py @@ -104,4 +104,3 @@ # # pytest.fail(f"Error occurred: {e}") # # test_openai_with_params() - diff --git a/litellm/tests/test_langsmith.py b/litellm/tests/test_langsmith.py index 0939f4cc3..603a8370d 100644 --- a/litellm/tests/test_langsmith.py +++ b/litellm/tests/test_langsmith.py @@ -2,7 +2,7 @@ import sys import os import io -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion import litellm @@ -14,59 +14,57 @@ import time def test_langsmith_logging(): try: - response = completion(model="claude-instant-1.2", - messages=[{ - "role": "user", - "content": "what llm are u" - }], - max_tokens=10, - temperature=0.2 - ) + response = completion( + model="claude-instant-1.2", + messages=[{"role": "user", "content": "what llm are u"}], + max_tokens=10, + temperature=0.2, + ) print(response) except Exception as e: print(e) + # test_langsmith_logging() def test_langsmith_logging_with_metadata(): try: - response = completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "what llm are u" - }], - max_tokens=10, - temperature=0.2, - metadata={ - "run_name": "litellmRUN", - "project_name": "litellm-completion", - } - ) + response = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "what llm are u"}], + max_tokens=10, + temperature=0.2, + metadata={ + "run_name": "litellmRUN", + "project_name": "litellm-completion", + }, + ) print(response) except Exception as e: print(e) + # test_langsmith_logging_with_metadata() + def test_langsmith_logging_with_streaming_and_metadata(): try: - response = completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "what llm are u" - }], - max_tokens=10, - temperature=0.2, - metadata={ - "run_name": "litellmRUN", - "project_name": "litellm-completion", - }, - stream=True - ) + response = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "what llm are u"}], + max_tokens=10, + temperature=0.2, + metadata={ + "run_name": "litellmRUN", + "project_name": "litellm-completion", + }, + stream=True, + ) for chunk in response: continue except Exception as e: print(e) -test_langsmith_logging_with_streaming_and_metadata() \ No newline at end of file + +test_langsmith_logging_with_streaming_and_metadata() diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py index 05d3f3ec6..849ff75ac 100644 --- a/litellm/tests/test_least_busy_routing.py +++ b/litellm/tests/test_least_busy_routing.py @@ -2,10 +2,10 @@ # # This tests the router's ability to identify the least busy deployment # # -# # How is this achieved? +# # How is this achieved? 
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"} # # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic} -# # - use litellm.success + failure callbacks to log when a request completed +# # - use litellm.success + failure callbacks to log when a request completed # # - in get_available_deployment, for a given model group name -> pick based on traffic # import sys, os, asyncio, time @@ -48,13 +48,13 @@ # "rpm": 6 # } # }] -# router = Router(model_list=model_list, +# router = Router(model_list=model_list, # routing_strategy="least-busy", # set_verbose=False, # num_retries=3) # type: ignore - + # async def call_azure_completion(): -# try: +# try: # response = await router.acompletion( # model="azure-model", # messages=[ @@ -66,9 +66,9 @@ # ) # print("\n response", response) # return response -# except: +# except: # return None - + # n = 1000 # start_time = time.time() # tasks = [call_azure_completion() for _ in range(n)] @@ -76,4 +76,4 @@ # successful_completions = [c for c in chat_completions if c is not None] # print(n, time.time() - start_time, len(successful_completions)) -# asyncio.run(test_least_busy_routing()) \ No newline at end of file +# asyncio.run(test_least_busy_routing()) diff --git a/litellm/tests/test_litellm_max_budget.py b/litellm/tests/test_litellm_max_budget.py index 0e933c604..9fcddfe32 100644 --- a/litellm/tests/test_litellm_max_budget.py +++ b/litellm/tests/test_litellm_max_budget.py @@ -3,12 +3,12 @@ # # commenting out this test for circle ci, as it causes other tests to fail, since litellm.max_budget would impact other litellm imports # import sys, os, json # import traceback -# import pytest +# import pytest # sys.path.insert( # 0, os.path.abspath("../..") # ) # Adds the parent directory to the system path -# import litellm +# import litellm # # litellm.set_verbose = True # from litellm import completion, BudgetExceededError @@ -18,14 +18,12 @@ # messages = [{"role": "user", "content": "Hey, how's it going"}] # response = completion(model="gpt-4", messages=messages, stream=True) -# for chunk in response: +# for chunk in response: # continue # print(litellm._current_cost) # completion(model="gpt-4", messages=messages, stream=True) # litellm.max_budget = float('inf') -# except BudgetExceededError as e: +# except BudgetExceededError as e: # pass # except Exception as e: # pytest.fail(f"An error occured: {str(e)}") - - diff --git a/litellm/tests/test_loadtest_router.py b/litellm/tests/test_loadtest_router.py index da031be69..5dc60ca3f 100644 --- a/litellm/tests/test_loadtest_router.py +++ b/litellm/tests/test_loadtest_router.py @@ -26,24 +26,24 @@ # async def main(): # # Initialize the Router # model_list= [{ -# "model_name": "gpt-3.5-turbo", +# "model_name": "gpt-3.5-turbo", # "litellm_params": { -# "model": "gpt-3.5-turbo", -# "api_key": os.getenv("OPENAI_API_KEY"), +# "model": "gpt-3.5-turbo", +# "api_key": os.getenv("OPENAI_API_KEY"), # }, # }, { -# "model_name": "gpt-3.5-turbo", +# "model_name": "gpt-3.5-turbo", # "litellm_params": { -# "model": "azure/chatgpt-v-2", -# "api_key": os.getenv("AZURE_API_KEY"), +# "model": "azure/chatgpt-v-2", +# "api_key": os.getenv("AZURE_API_KEY"), # "api_base": os.getenv("AZURE_API_BASE"), # "api_version": os.getenv("AZURE_API_VERSION") # }, # }, { -# "model_name": "gpt-3.5-turbo", +# "model_name": "gpt-3.5-turbo", # "litellm_params": { -# "model": "azure/chatgpt-functioncalling", -# "api_key": 
os.getenv("AZURE_API_KEY"), +# "model": "azure/chatgpt-functioncalling", +# "api_key": os.getenv("AZURE_API_KEY"), # "api_base": os.getenv("AZURE_API_BASE"), # "api_version": os.getenv("AZURE_API_VERSION") # }, diff --git a/litellm/tests/test_logging.py b/litellm/tests/test_logging.py index d8557d4c9..1a35d8454 100644 --- a/litellm/tests/test_logging.py +++ b/litellm/tests/test_logging.py @@ -39,7 +39,7 @@ # messages = [{"content": user_message, "role": "user"}] # # 1. On Call Success -# # normal completion +# # normal completion # # test on openai completion call # def test_logging_success_completion(): # global score @@ -73,7 +73,7 @@ # # sys.stdout = new_stdout = io.StringIO() # # response = completion(model="claude-instant-1", messages=messages) - + # # # Restore stdout # # sys.stdout = old_stdout # # output = new_stdout.getvalue().strip() @@ -100,9 +100,9 @@ # completion_response, # response from completion # start_time, end_time # start/end time # ): -# if "complete_streaming_response" in kwargs: +# if "complete_streaming_response" in kwargs: # print(f"Complete Streaming Response: {kwargs['complete_streaming_response']}") - + # # Assign the custom callback function # litellm.success_callback = [custom_callback] @@ -111,7 +111,7 @@ # sys.stdout = new_stdout = io.StringIO() # response = completion(model="gpt-3.5-turbo", messages=messages, stream=True) -# for chunk in response: +# for chunk in response: # pass # # Restore stdout @@ -131,7 +131,7 @@ # pytest.fail(f"Error occurred: {e}") # pass -# # test_logging_success_streaming_openai() +# # test_logging_success_streaming_openai() # ## test on non-openai completion call # def test_logging_success_streaming_non_openai(): @@ -144,9 +144,9 @@ # start_time, end_time # start/end time # ): # # print(f"streaming response: {completion_response}") -# if "complete_streaming_response" in kwargs: +# if "complete_streaming_response" in kwargs: # print(f"Complete Streaming Response: {kwargs['complete_streaming_response']}") - + # # Assign the custom callback function # litellm.success_callback = [custom_callback] @@ -155,9 +155,9 @@ # sys.stdout = new_stdout = io.StringIO() # response = completion(model="claude-instant-1", messages=messages, stream=True) -# for idx, chunk in enumerate(response): +# for idx, chunk in enumerate(response): # pass - + # # Restore stdout # sys.stdout = old_stdout # output = new_stdout.getvalue().strip() @@ -175,7 +175,7 @@ # pytest.fail(f"Error occurred: {e}") # pass -# # test_logging_success_streaming_non_openai() +# # test_logging_success_streaming_non_openai() # # embedding # def test_logging_success_embedding_openai(): @@ -202,7 +202,7 @@ # # ## 2. 
On LiteLLM Call failure # # ## TEST BAD KEY -# # # normal completion +# # # normal completion # # ## test on openai completion call # # try: # # temporary_oai_key = os.environ["OPENAI_API_KEY"] @@ -215,7 +215,7 @@ # # # Redirect stdout # # old_stdout = sys.stdout # # sys.stdout = new_stdout = io.StringIO() - + # # try: # # response = completion(model="gpt-3.5-turbo", messages=messages) # # except AuthenticationError: @@ -229,14 +229,14 @@ # # if "Logging Details Pre-API Call" not in output: # # raise Exception("Required log message not found!") -# # elif "Logging Details Post-API Call" not in output: +# # elif "Logging Details Post-API Call" not in output: # # raise Exception("Required log message not found!") # # elif "Logging Details LiteLLM-Failure Call" not in output: # # raise Exception("Required log message not found!") # # os.environ["OPENAI_API_KEY"] = temporary_oai_key # # os.environ["ANTHROPIC_API_KEY"] = temporary_anthropic_key - + # # score += 1 # # except Exception as e: # # print(f"exception type: {type(e).__name__}") @@ -307,7 +307,7 @@ # # raise Exception("Required log message not found!") # # elif "Logging Details LiteLLM-Failure Call" not in output: # # raise Exception("Required log message not found!") - + # # os.environ["OPENAI_API_KEY"] = temporary_oai_key # # os.environ["ANTHROPIC_API_KEY"] = temporary_anthropic_key # # score += 1 @@ -330,7 +330,7 @@ # # response = completion(model="claude-instant-1", messages=messages) # # except AuthenticationError: # # pass - + # # # Restore stdout # # sys.stdout = old_stdout # # output = new_stdout.getvalue().strip() @@ -379,4 +379,4 @@ # # raise Exception("Required log message not found!") # # except Exception as e: # # print(f"exception type: {type(e).__name__}") -# # pytest.fail(f"Error occurred: {e}") \ No newline at end of file +# # pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_longer_context_fallback.py b/litellm/tests/test_longer_context_fallback.py index 8f8942897..07e9e8cad 100644 --- a/litellm/tests/test_longer_context_fallback.py +++ b/litellm/tests/test_longer_context_fallback.py @@ -4,10 +4,11 @@ import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm from litellm import longer_context_model_fallback_dict -print(longer_context_model_fallback_dict) \ No newline at end of file +print(longer_context_model_fallback_dict) diff --git a/litellm/tests/test_mock_request.py b/litellm/tests/test_mock_request.py index 82eb30926..4a2d661c6 100644 --- a/litellm/tests/test_mock_request.py +++ b/litellm/tests/test_mock_request.py @@ -1,5 +1,5 @@ #### What this tests #### -# This tests mock request calls to litellm +# This tests mock request calls to litellm import sys, os import traceback @@ -9,6 +9,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm + def test_mock_request(): try: model = "gpt-3.5-turbo" @@ -19,18 +20,20 @@ def test_mock_request(): except: traceback.print_exc() + # test_mock_request() def test_streaming_mock_request(): - try: + try: model = "gpt-3.5-turbo" messages = [{"role": "user", "content": "Hey, I'm a mock request"}] response = litellm.mock_completion(model=model, messages=messages, stream=True) - complete_response = "" - for chunk in response: + complete_response = "" + for chunk in response: complete_response += chunk["choices"][0]["delta"]["content"] - if complete_response == "": + if complete_response == "": raise Exception("Empty response received") 
except: traceback.print_exc() -test_streaming_mock_request() \ No newline at end of file + +test_streaming_mock_request() diff --git a/litellm/tests/test_model_alias_map.py b/litellm/tests/test_model_alias_map.py index b99a626e3..1501f49e4 100644 --- a/litellm/tests/test_model_alias_map.py +++ b/litellm/tests/test_model_alias_map.py @@ -13,9 +13,7 @@ import pytest litellm.set_verbose = True -model_alias_map = { - "good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf" -} +model_alias_map = {"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"} def test_model_alias_map(): @@ -32,6 +30,6 @@ def test_model_alias_map(): assert "Llama-2-7b-chat-hf" in response.model except Exception as e: pytest.fail(f"Error occurred: {e}") - -test_model_alias_map() \ No newline at end of file + +test_model_alias_map() diff --git a/litellm/tests/test_multiple_deployments.py b/litellm/tests/test_multiple_deployments.py index 35161194c..0a79cb518 100644 --- a/litellm/tests/test_multiple_deployments.py +++ b/litellm/tests/test_multiple_deployments.py @@ -11,36 +11,44 @@ import pytest import litellm from litellm import completion -messages=[{"role": "user", "content": "Hey, how's it going?"}] +messages = [{"role": "user", "content": "Hey, how's it going?"}] ## All your mistral deployments ## -model_list = [{ - "model_name": "mistral-7b-instruct", - "litellm_params": { # params for litellm completion/embedding call - "model": "replicate/mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70", - "api_key": os.getenv("REPLICATE_API_KEY"), - } -}, { - "model_name": "mistral-7b-instruct", - "litellm_params": { # params for litellm completion/embedding call - "model": "together_ai/mistralai/Mistral-7B-Instruct-v0.1", - "api_key": os.getenv("TOGETHERAI_API_KEY"), - } -}, { - "model_name": "mistral-7b-instruct", - "litellm_params": { - "model": "deepinfra/mistralai/Mistral-7B-Instruct-v0.1", - "api_key": os.getenv("DEEPINFRA_API_KEY") - } -}] +model_list = [ + { + "model_name": "mistral-7b-instruct", + "litellm_params": { # params for litellm completion/embedding call + "model": "replicate/mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70", + "api_key": os.getenv("REPLICATE_API_KEY"), + }, + }, + { + "model_name": "mistral-7b-instruct", + "litellm_params": { # params for litellm completion/embedding call + "model": "together_ai/mistralai/Mistral-7B-Instruct-v0.1", + "api_key": os.getenv("TOGETHERAI_API_KEY"), + }, + }, + { + "model_name": "mistral-7b-instruct", + "litellm_params": { + "model": "deepinfra/mistralai/Mistral-7B-Instruct-v0.1", + "api_key": os.getenv("DEEPINFRA_API_KEY"), + }, + }, +] + def test_multiple_deployments(): - try: - ## LiteLLM completion call ## returns first response - response = completion(model="mistral-7b-instruct", messages=messages, model_list=model_list) + try: + ## LiteLLM completion call ## returns first response + response = completion( + model="mistral-7b-instruct", messages=messages, model_list=model_list + ) print(f"response: {response}") except Exception as e: traceback.print_exc() pytest.fail(f"An exception occurred: {e}") -test_multiple_deployments() \ No newline at end of file + +test_multiple_deployments() diff --git a/litellm/tests/test_ollama.py b/litellm/tests/test_ollama.py index a9f3f5468..82ec16f0e 100644 --- a/litellm/tests/test_ollama.py +++ b/litellm/tests/test_ollama.py @@ -7,7 +7,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to 
the system path +) # Adds the parent directory to the system path import pytest import litellm @@ -15,13 +15,26 @@ import litellm ## for ollama we can't test making the completion call from litellm.utils import get_optional_params, get_llm_provider + def test_get_ollama_params(): try: - converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", max_tokens=20, temperature=0.5, stream=True) + converted_params = get_optional_params( + custom_llm_provider="ollama", + model="llama2", + max_tokens=20, + temperature=0.5, + stream=True, + ) print("Converted params", converted_params) - assert converted_params == {'num_predict': 20, 'stream': True, 'temperature': 0.5}, f"{converted_params} != {'num_predict': 20, 'stream': True, 'temperature': 0.5}" + assert converted_params == { + "num_predict": 20, + "stream": True, + "temperature": 0.5, + }, f"{converted_params} != {'num_predict': 20, 'stream': True, 'temperature': 0.5}" except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_get_ollama_params() @@ -44,4 +57,4 @@ def test_ollama_json_mode(): assert converted_params == {'temperature': 0.5, 'format': 'json'}, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}" except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_ollama_json_mode() \ No newline at end of file +# test_ollama_json_mode() diff --git a/litellm/tests/test_ollama_local.py b/litellm/tests/test_ollama_local.py index 3eb996970..d4dbc7341 100644 --- a/litellm/tests/test_ollama_local.py +++ b/litellm/tests/test_ollama_local.py @@ -25,7 +25,7 @@ # asyncio.run(test_ollama_aembeddings()) # def test_ollama_streaming(): -# try: +# try: # litellm.set_verbose = False # messages = [ # {"role": "user", "content": "What is the weather like in Boston?"} @@ -50,26 +50,26 @@ # } # } # ] -# response = litellm.completion(model="ollama/mistral", +# response = litellm.completion(model="ollama/mistral", # messages=messages, # functions=functions, # stream=True) -# for chunk in response: +# for chunk in response: # print(f"CHUNK: {chunk}") -# except Exception as e: +# except Exception as e: # print(e) # # test_ollama_streaming() # async def test_async_ollama_streaming(): -# try: +# try: # litellm.set_verbose = False -# response = await litellm.acompletion(model="ollama/mistral-openorca", +# response = await litellm.acompletion(model="ollama/mistral-openorca", # messages=[{"role": "user", "content": "Hey, how's it going?"}], # stream=True) -# async for chunk in response: +# async for chunk in response: # print(f"CHUNK: {chunk}") -# except Exception as e: +# except Exception as e: # print(e) # # asyncio.run(test_async_ollama_streaming()) @@ -78,13 +78,13 @@ # try: # litellm.set_verbose = True # response = completion( -# model="ollama/mistral", -# messages=[{"role": "user", "content": "Hey, how's it going?"}], +# model="ollama/mistral", +# messages=[{"role": "user", "content": "Hey, how's it going?"}], # max_tokens=200, # request_timeout = 10, # stream=True # ) -# for chunk in response: +# for chunk in response: # print(chunk) # print(response) # except Exception as e: @@ -119,13 +119,13 @@ # } # ] # response = completion( -# model="ollama/mistral", +# model="ollama/mistral", # messages=messages, -# functions=functions, +# functions=functions, # max_tokens=200, # request_timeout = 10, # ) -# for chunk in response: +# for chunk in response: # print(chunk) # print(response) # except Exception as e: @@ -159,9 +159,9 @@ # } # ] # response = await litellm.acompletion( -# model="ollama/mistral", +# 
model="ollama/mistral", # messages=messages, -# functions=functions, +# functions=functions, # max_tokens=200, # request_timeout = 10, # ) @@ -175,8 +175,8 @@ # def test_completion_ollama_with_api_base(): # try: # response = completion( -# model="ollama/llama2", -# messages=messages, +# model="ollama/llama2", +# messages=messages, # api_base="http://localhost:11434" # ) # print(response) @@ -200,8 +200,8 @@ # litellm.set_verbose = True # try: # response = completion( -# model="ollama/llama2", -# messages=messages, +# model="ollama/llama2", +# messages=messages, # stream=True # ) # print(response) @@ -220,20 +220,20 @@ # messages = [{ "content": user_message,"role": "user"}] # try: # response = await litellm.acompletion( -# model="ollama/llama2", -# messages=messages, -# api_base="http://localhost:11434", +# model="ollama/llama2", +# messages=messages, +# api_base="http://localhost:11434", # stream=True # ) # async for chunk in response: # print(chunk['choices'][0]['delta']) - + # print("TEST ASYNC NON Stream") # response = await litellm.acompletion( -# model="ollama/llama2", -# messages=messages, -# api_base="http://localhost:11434", +# model="ollama/llama2", +# messages=messages, +# api_base="http://localhost:11434", # ) # print(response) # except Exception as e: @@ -243,7 +243,6 @@ # # asyncio.run(test_completion_ollama_async_stream()) - # def prepare_messages_for_chat(text: str) -> list: # messages = [ # {"role": "user", "content": text}, @@ -265,7 +264,7 @@ # response = await ask_question() # async for chunk in response: # print(chunk) - + # print("test async completion without streaming") # response = await litellm.acompletion( # model="ollama/llama2", @@ -282,8 +281,8 @@ # messages = [{ "content": user_message,"role": "user"}] # try: # response = completion( -# model="ollama/invalid", -# messages=messages, +# model="ollama/invalid", +# messages=messages, # stream=True # ) # print(response) @@ -302,7 +301,7 @@ # litellm.set_verbose=True # # same params as gpt-4 vision # response = completion( -# model = "ollama/llava", +# model = "ollama/llava", # messages=[ # { # "role": "user", @@ -323,7 +322,7 @@ # ) # print("Response from ollama/llava") # print(response) -# # test_ollama_llava() +# # test_ollama_llava() # # PROCESSED CHUNK PRE CHUNK CREATOR diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 0ce917afe..784918e88 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -2,26 +2,37 @@ # This tests if get_optional_params works as expected import sys, os, time, inspect, asyncio, traceback import pytest -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm.utils import get_optional_params_embeddings -## get_optional_params_embeddings -### Models: OpenAI, Azure, Bedrock -### Scenarios: w/ optional params + litellm.drop_params = True + +## get_optional_params_embeddings +### Models: OpenAI, Azure, Bedrock +### Scenarios: w/ optional params + litellm.drop_params = True + def test_bedrock_optional_params_embeddings(): litellm.drop_params = True - optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="bedrock") + optional_params = get_optional_params_embeddings( + user="John", encoding_format=None, custom_llm_provider="bedrock" + ) assert len(optional_params) == 0 + def test_openai_optional_params_embeddings(): litellm.drop_params = True - optional_params = 
get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="openai") + optional_params = get_optional_params_embeddings( + user="John", encoding_format=None, custom_llm_provider="openai" + ) assert len(optional_params) == 1 assert optional_params["user"] == "John" + def test_azure_optional_params_embeddings(): litellm.drop_params = True - optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="azure") + optional_params = get_optional_params_embeddings( + user="John", encoding_format=None, custom_llm_provider="azure" + ) assert len(optional_params) == 1 assert optional_params["user"] == "John" diff --git a/litellm/tests/test_profiling_router.py b/litellm/tests/test_profiling_router.py index 48ed9cb0e..5e1646847 100644 --- a/litellm/tests/test_profiling_router.py +++ b/litellm/tests/test_profiling_router.py @@ -33,7 +33,7 @@ # litellm.telemetry = False -# num_task_cancelled_errors = 0 +# num_task_cancelled_errors = 0 # model_list = [{ # "model_name": "azure-model", @@ -63,23 +63,23 @@ # router = Router(model_list=model_list, set_verbose=False, num_retries=3) -# async def router_completion(): +# async def router_completion(): # global num_task_cancelled_errors, exception_counts -# try: +# try: # messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}] # response = await router.acompletion(model="azure-model", messages=messages) # return response # except asyncio.exceptions.CancelledError: # exception_type = "CancelledError" # exception_counts[exception_type] = exception_counts.get(exception_type, 0) + 1 -# print("Task was cancelled") +# print("Task was cancelled") # num_task_cancelled_errors += 1 # exception_data.append({ # "type": exception_type, # "traceback": None # }) # return None -# except Exception as e: +# except Exception as e: # exception_type = type(e).__name__ # exception_counts[exception_type] = exception_counts.get(exception_type, 0) + 1 # exception_data.append({ @@ -95,12 +95,12 @@ # chat_completions = await asyncio.gather(*tasks) # successful_completions = [c for c in chat_completions if c is not None] # print(n, time.time() - start, len(successful_completions)) - + # # Print exception breakdown # print("Exception Breakdown:") # for exception_type, count in exception_counts.items(): # print(f"{exception_type}: {count}") - + # # Store exception_data in a file # with open('exception_data.txt', 'w') as file: # for data in exception_data: @@ -130,7 +130,7 @@ # # total_successful_requests = 0 # # request_limit = 1000 # # batches = 2 # batches of 1k requests -# # start = time.time() +# # start = time.time() # # tasks = [] # list to hold all tasks # # async def request_loop(): @@ -139,7 +139,7 @@ # # # Make 1,000 requests # # task = asyncio.create_task(make_requests(request_limit)) # # tasks.append(task) - + # # # Introduce a delay to achieve 1,000 requests per second # # await asyncio.sleep(1) @@ -149,4 +149,4 @@ # # print(request_limit*batches, time.time() - start, total_successful_requests) -# # asyncio.run(main()) \ No newline at end of file +# # asyncio.run(main()) diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 11ebbb424..5b5383084 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -4,14 +4,18 @@ import sys import os import io -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # from litellm.llms.prompt_templates.factory import prompt_factory from litellm 
import completion + def codellama_prompt_format(): model = "huggingface/codellama/CodeLlama-7b-Instruct-hf" - messages = [{"role": "system", "content": "You are a good bot"}, {"role": "user", "content": "Hey, how's it going?"}] + messages = [ + {"role": "system", "content": "You are a good bot"}, + {"role": "user", "content": "Hey, how's it going?"}, + ] expected_response = """[INST] <> You are a good bot <> @@ -20,4 +24,5 @@ You are a good bot response = completion(model=model, messages=messages) print(response) -# codellama_prompt_format() \ No newline at end of file + +# codellama_prompt_format() diff --git a/litellm/tests/test_promptlayer_integration.py b/litellm/tests/test_promptlayer_integration.py index 8c919bb1e..9f0af1af8 100644 --- a/litellm/tests/test_promptlayer_integration.py +++ b/litellm/tests/test_promptlayer_integration.py @@ -2,7 +2,7 @@ import sys import os import io -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion import litellm @@ -12,7 +12,6 @@ litellm.set_verbose = True import time - # def test_promptlayer_logging(): # try: # # Redirect stdout @@ -46,14 +45,13 @@ def test_promptlayer_logging_with_metadata(): old_stdout = sys.stdout sys.stdout = new_stdout = io.StringIO() - response = completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm ai21" - }], - temperature=0.2, - max_tokens=20, - metadata={"model": "ai21"}) + response = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm ai21"}], + temperature=0.2, + max_tokens=20, + metadata={"model": "ai21"}, + ) # Restore stdout time.sleep(1) @@ -66,11 +64,10 @@ def test_promptlayer_logging_with_metadata(): except Exception as e: print(e) + test_promptlayer_logging_with_metadata() - - # def test_chat_openai(): # try: # response = completion(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1", diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index 06282e234..55986ff70 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -1,10 +1,11 @@ #### What this tests #### # This tests setting provider specific configs across providers -# There are 2 types of tests - changing config dynamically or by setting class variables +# There are 2 types of tests - changing config dynamically or by setting class variables import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -12,7 +13,7 @@ import litellm from litellm import completion from litellm import RateLimitError -# Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten?? +# Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten?? 
# def hf_test_completion_tgi(): # litellm.HuggingfaceConfig(max_new_tokens=200) # litellm.set_verbose=True @@ -43,7 +44,8 @@ from litellm import RateLimitError # pytest.fail(f"Error occurred: {e}") # hf_test_completion_tgi() -#Anthropic +# Anthropic + def claude_test_completion(): litellm.AnthropicConfig(max_tokens_to_sample=200) @@ -52,8 +54,8 @@ def claude_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="claude-instant-1", - messages=[{ "content": "Hello, how are you?","role": "user"}], - max_tokens=10 + messages=[{"content": "Hello, how are you?", "role": "user"}], + max_tokens=10, ) # Add any assertions here to check the response print(response_1) @@ -62,7 +64,7 @@ def claude_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="claude-instant-1", - messages=[{ "content": "Hello, how are you?","role": "user"}], + messages=[{"content": "Hello, how are you?", "role": "user"}], ) # Add any assertions here to check the response print(response_2) @@ -70,20 +72,24 @@ def claude_test_completion(): assert len(response_2_text) > len(response_1_text) - try: - response_3 = litellm.completion(model="claude-instant-1", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) - - except Exception as e: + try: + response_3 = litellm.completion( + model="claude-instant-1", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) + + except Exception as e: print(e) except Exception as e: pytest.fail(f"Error occurred: {e}") + # claude_test_completion() # Replicate + def replicate_test_completion(): litellm.ReplicateConfig(max_new_tokens=200) # litellm.set_verbose=True @@ -91,8 +97,8 @@ def replicate_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - messages=[{ "content": "Hello, how are you?","role": "user"}], - max_tokens=10 + messages=[{"content": "Hello, how are you?", "role": "user"}], + max_tokens=10, ) # Add any assertions here to check the response print(response_1) @@ -101,67 +107,80 @@ def replicate_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - messages=[{ "content": "Hello, how are you?","role": "user"}], + messages=[{"content": "Hello, how are you?", "role": "user"}], ) # Add any assertions here to check the response print(response_2) response_2_text = response_2.choices[0].message.content assert len(response_2_text) > len(response_1_text) - try: - response_3 = litellm.completion(model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) - except: + try: + response_3 = litellm.completion( + model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) + except: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # replicate_test_completion() # Cohere + def cohere_test_completion(): # litellm.CohereConfig(max_tokens=200) - litellm.set_verbose=True + litellm.set_verbose = True try: # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="command-nightly", - messages=[{ "content": "Hello, how are you?","role": "user"}], - max_tokens=10 + messages=[{"content": "Hello, how are you?", "role": 
"user"}], + max_tokens=10, ) response_1_text = response_1.choices[0].message.content # USE CONFIG TOKENS response_2 = litellm.completion( model="command-nightly", - messages=[{ "content": "Hello, how are you?","role": "user"}], + messages=[{"content": "Hello, how are you?", "role": "user"}], ) response_2_text = response_2.choices[0].message.content assert len(response_2_text) > len(response_1_text) - response_3 = litellm.completion(model="command-nightly", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + response_3 = litellm.completion( + model="command-nightly", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) assert len(response_3.choices) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + # cohere_test_completion() # AI21 + def ai21_test_completion(): litellm.AI21Config(maxTokens=10) - litellm.set_verbose=True + litellm.set_verbose = True try: # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="j2-mid", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -169,33 +188,47 @@ def ai21_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="j2-mid", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - response_3 = litellm.completion(model="j2-light", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + response_3 = litellm.completion( + model="j2-light", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) assert len(response_3.choices) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + # ai21_test_completion() # TogetherAI + def togetherai_test_completion(): litellm.TogetherAIConfig(max_tokens=10) - litellm.set_verbose=True + litellm.set_verbose = True try: # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="together_ai/togethercomputer/llama-2-70b-chat", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -203,27 +236,36 @@ def togetherai_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="together_ai/togethercomputer/llama-2-70b-chat", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? 
Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - try: - response_3 = litellm.completion(model="together_ai/togethercomputer/llama-2-70b-chat", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + try: + response_3 = litellm.completion( + model="together_ai/togethercomputer/llama-2-70b-chat", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) pytest.fail(f"Error not raised when n=2 passed to provider") - except: + except: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # togetherai_test_completion() # Palm + def palm_test_completion(): litellm.PalmConfig(max_output_tokens=10, temperature=0.9) # litellm.set_verbose=True @@ -231,8 +273,13 @@ def palm_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="palm/chat-bison", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -240,24 +287,33 @@ def palm_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="palm/chat-bison", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - response_3 = litellm.completion(model="palm/chat-bison", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + response_3 = litellm.completion( + model="palm/chat-bison", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) assert len(response_3.choices) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + # palm_test_completion() # NLP Cloud + def nlp_cloud_test_completion(): litellm.NLPCloudConfig(max_length=10) # litellm.set_verbose=True @@ -265,8 +321,13 @@ def nlp_cloud_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="dolphin", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -274,27 +335,36 @@ def nlp_cloud_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="dolphin", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? 
Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - try: - response_3 = litellm.completion(model="dolphin", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + try: + response_3 = litellm.completion( + model="dolphin", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) pytest.fail(f"Error not raised when n=2 passed to provider") - except: + except: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # nlp_cloud_test_completion() # AlephAlpha + def aleph_alpha_test_completion(): litellm.AlephAlphaConfig(maximum_tokens=10) # litellm.set_verbose=True @@ -302,8 +372,13 @@ def aleph_alpha_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="luminous-base", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -311,24 +386,32 @@ def aleph_alpha_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="luminous-base", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - response_3 = litellm.completion(model="luminous-base", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) - + response_3 = litellm.completion( + model="luminous-base", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) + assert len(response_3.choices) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + # aleph_alpha_test_completion() -# Petals - calls are too slow, will cause circle ci to fail due to delay. Test locally. +# Petals - calls are too slow, will cause circle ci to fail due to delay. Test locally. # def petals_completion(): # litellm.PetalsConfig(max_new_tokens=10) # # litellm.set_verbose=True @@ -359,7 +442,7 @@ def aleph_alpha_test_completion(): # petals_completion() # VertexAI -# We don't have vertex ai configured for circle ci yet -- need to figure this out. +# We don't have vertex ai configured for circle ci yet -- need to figure this out. # def vertex_ai_test_completion(): # litellm.VertexAIConfig(max_output_tokens=10) # # litellm.set_verbose=True @@ -389,6 +472,7 @@ def aleph_alpha_test_completion(): # Sagemaker + def sagemaker_test_completion(): litellm.SagemakerConfig(max_new_tokens=10) # litellm.set_verbose=True @@ -396,8 +480,13 @@ def sagemaker_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? 
Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -405,7 +494,12 @@ def sagemaker_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") @@ -414,6 +508,7 @@ def sagemaker_test_completion(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # sagemaker_test_completion() # Bedrock @@ -426,8 +521,13 @@ def bedrock_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="bedrock/cohere.command-text-v14", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -435,7 +535,12 @@ def bedrock_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="bedrock/cohere.command-text-v14", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") @@ -446,8 +551,10 @@ def bedrock_test_completion(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # bedrock_test_completion() + # OpenAI Chat Completion def openai_test_completion(): litellm.OpenAIConfig(max_tokens=10) @@ -456,8 +563,13 @@ def openai_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="gpt-3.5-turbo", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -465,7 +577,12 @@ def openai_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="gpt-3.5-turbo", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") @@ -474,8 +591,10 @@ def openai_test_completion(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # openai_test_completion() + # OpenAI Text Completion def openai_text_completion_test(): litellm.OpenAITextCompletionConfig(max_tokens=10) @@ -484,8 +603,13 @@ def openai_text_completion_test(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="text-davinci-003", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? 
Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -493,23 +617,32 @@ def openai_text_completion_test(): # USE CONFIG TOKENS response_2 = litellm.completion( model="text-davinci-003", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") assert len(response_2_text) < len(response_1_text) - response_3 = litellm.completion(model="text-davinci-003", - messages=[{ "content": "Hello, how are you?","role": "user"}], - n=2) + response_3 = litellm.completion( + model="text-davinci-003", + messages=[{"content": "Hello, how are you?", "role": "user"}], + n=2, + ) assert len(response_3.choices) > 1 except Exception as e: pytest.fail(f"Error occurred: {e}") + # openai_text_completion_test() -# Azure OpenAI + +# Azure OpenAI def azure_openai_test_completion(): litellm.AzureOpenAIConfig(max_tokens=10) # litellm.set_verbose=True @@ -517,8 +650,13 @@ def azure_openai_test_completion(): # OVERRIDE WITH DYNAMIC MAX TOKENS response_1 = litellm.completion( model="azure/chatgpt-v-2", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], - max_tokens=100 + messages=[ + { + "content": "Hello, how are you? Be as verbose as possible", + "role": "user", + } + ], + max_tokens=100, ) response_1_text = response_1.choices[0].message.content print(f"response_1_text: {response_1_text}") @@ -526,7 +664,12 @@ def azure_openai_test_completion(): # USE CONFIG TOKENS response_2 = litellm.completion( model="azure/chatgpt-v-2", - messages=[{ "content": "Hello, how are you? Be as verbose as possible","role": "user"}], + messages=[ + { + "content": "Hello, how are you? 
Be as verbose as possible", + "role": "user", + } + ], ) response_2_text = response_2.choices[0].message.content print(f"response_2_text: {response_2_text}") @@ -535,4 +678,5 @@ def azure_openai_test_completion(): except Exception as e: pytest.fail(f"Error occurred: {e}") -# azure_openai_test_completion() \ No newline at end of file + +# azure_openai_test_completion() diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py index c96acb816..f16f1d379 100644 --- a/litellm/tests/test_proxy_custom_auth.py +++ b/litellm/tests/test_proxy_custom_auth.py @@ -9,7 +9,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest import litellm from litellm import embedding, completion, completion_cost, Timeout @@ -18,7 +18,11 @@ from litellm import RateLimitError # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + initialize, +) # Replace with the actual module where your FastAPI router is defined # Here you create a fixture that will be used by your tests @@ -26,6 +30,7 @@ from litellm.proxy.proxy_server import router, save_worker_config, initialize # @pytest.fixture(scope="function") def client(): from litellm.proxy.proxy_server import cleanup_router_config_variables + cleanup_router_config_variables() filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" @@ -39,27 +44,22 @@ def client(): def test_custom_auth(client): try: - # Your test data + # Your test data test_data = { "model": "openai-model", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } # Your bearer token token = os.getenv("PROXY_MASTER_KEY") - headers = { - "Authorization": f"Bearer {token}" - } + headers = {"Authorization": f"Bearer {token}"} response = client.post("/chat/completions", json=test_data, headers=headers) print(f"response: {response.text}") assert response.status_code == 401 result = response.json() print(f"Received response: {result}") except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) \ No newline at end of file + pytest.fail("LiteLLM Proxy test failed. 
Exception", e) diff --git a/litellm/tests/test_proxy_custom_logger.py b/litellm/tests/test_proxy_custom_logger.py index 0a3097af9..4866c5b16 100644 --- a/litellm/tests/test_proxy_custom_logger.py +++ b/litellm/tests/test_proxy_custom_logger.py @@ -9,7 +9,7 @@ import os, io, asyncio sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest import litellm from litellm import embedding, completion, completion_cost, Timeout @@ -19,16 +19,22 @@ import importlib, inspect # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + initialize, +) # Replace with the actual module where your FastAPI router is defined + filepath = os.path.dirname(os.path.abspath(__file__)) python_file_path = f"{filepath}/test_configs/custom_callbacks.py" # @app.on_event("startup") # async def wrapper_startup_event(): - # initialize(config=config_fp) +# initialize(config=config_fp) # Use the app fixture in your client fixture + @pytest.fixture def client(): filepath = os.path.dirname(os.path.abspath(__file__)) @@ -38,25 +44,23 @@ def client(): app.include_router(router) # Include your router in the test app return TestClient(app) - # Your bearer token token = os.getenv("PROXY_MASTER_KEY") -headers = { - "Authorization": f"Bearer {token}" -} +headers = {"Authorization": f"Bearer {token}"} print("Testing proxy custom logger") + def test_embedding(client): try: - litellm.set_verbose=False + litellm.set_verbose = False from litellm.proxy.utils import get_instance_fn + my_custom_logger = get_instance_fn( - value = "custom_callbacks.my_custom_logger", - config_file_path=python_file_path + value="custom_callbacks.my_custom_logger", config_file_path=python_file_path ) print("id of initialized custom logger", id(my_custom_logger)) litellm.callbacks = [my_custom_logger] @@ -69,26 +73,50 @@ def test_embedding(client): print("my_custom_logger", my_custom_logger) assert my_custom_logger.async_success_embedding == False - test_data = { - "model": "azure-embedding-model", - "input": ["hello"] - } + test_data = {"model": "azure-embedding-model", "input": ["hello"]} response = client.post("/embeddings", json=test_data, headers=headers) print("made request", response.status_code, response.text) - print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger)) - assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true - assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct + print( + "vars my custom logger /embeddings", + vars(my_custom_logger), + "id", + id(my_custom_logger), + ) + assert ( + my_custom_logger.async_success_embedding == True + ) # checks if the status of async_success is True, only the async_log_success_event can set this to true + assert ( + my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" + ) # checks if kwargs passed to async_log_success_event are correct kwargs = my_custom_logger.async_embedding_kwargs litellm_params = kwargs.get("litellm_params") metadata = litellm_params.get("metadata", 
None) print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata")) assert metadata is not None assert "user_api_key" in metadata - assert "headers" in metadata + assert "headers" in metadata proxy_server_request = litellm_params.get("proxy_server_request") model_info = litellm_params.get("model_info") - assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}} - assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'} + assert proxy_server_request == { + "url": "http://testserver/embeddings", + "method": "POST", + "headers": { + "host": "testserver", + "accept": "*/*", + "accept-encoding": "gzip, deflate", + "connection": "keep-alive", + "user-agent": "testclient", + "authorization": "Bearer sk-1234", + "content-length": "54", + "content-type": "application/json", + }, + "body": {"model": "azure-embedding-model", "input": ["hello"]}, + } + assert model_info == { + "input_cost_per_token": 0.002, + "mode": "embedding", + "id": "hello", + } result = response.json() print(f"Received response: {result}") print("Passed Embedding custom logger on proxy!") @@ -98,12 +126,12 @@ def test_embedding(client): def test_chat_completion(client): try: - # Your test data - litellm.set_verbose=False + # Your test data + litellm.set_verbose = False from litellm.proxy.utils import get_instance_fn + my_custom_logger = get_instance_fn( - value = "custom_callbacks.my_custom_logger", - config_file_path=python_file_path + value="custom_callbacks.my_custom_logger", config_file_path=python_file_path ) print("id of initialized custom logger", id(my_custom_logger)) @@ -121,36 +149,66 @@ def test_chat_completion(client): test_data = { "model": "Azure OpenAI GPT-4 Canada", "messages": [ - { - "role": "user", - "content": "write a litellm poem" - }, + {"role": "user", "content": "write a litellm poem"}, ], "max_tokens": 10, } - response = client.post("/chat/completions", json=test_data, headers=headers) print("made request", response.status_code, response.text) print("LiteLLM Callbacks", litellm.callbacks) - asyncio.sleep(1) # sleep while waiting for callback to run + asyncio.sleep(1) # sleep while waiting for callback to run - print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger)) + print( + "my_custom_logger in /chat/completions", + my_custom_logger, + "id", + id(my_custom_logger), + ) print("vars my custom logger, ", vars(my_custom_logger)) - assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true - assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct - print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs) + assert ( + my_custom_logger.async_success == True + ) # checks if the status of async_success is True, only the async_log_success_event can set this to true + assert ( + my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" + ) # checks if kwargs passed to async_log_success_event are correct + print( + "\n\n Custom Logger Async Completion args", + 
my_custom_logger.async_completion_kwargs, + ) litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params") metadata = litellm_params.get("metadata", None) print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata")) assert metadata is not None assert "user_api_key" in metadata - assert "headers" in metadata + assert "headers" in metadata config_model_info = litellm_params.get("model_info") proxy_server_request_object = litellm_params.get("proxy_server_request") - assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'} - assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}} + assert config_model_info == { + "id": "gm", + "input_cost_per_token": 0.0002, + "mode": "chat", + } + assert proxy_server_request_object == { + "url": "http://testserver/chat/completions", + "method": "POST", + "headers": { + "host": "testserver", + "accept": "*/*", + "accept-encoding": "gzip, deflate", + "connection": "keep-alive", + "user-agent": "testclient", + "authorization": "Bearer sk-1234", + "content-length": "123", + "content-type": "application/json", + }, + "body": { + "model": "Azure OpenAI GPT-4 Canada", + "messages": [{"role": "user", "content": "write a litellm poem"}], + "max_tokens": 10, + }, + } result = response.json() print(f"Received response: {result}") print("\nPassed /chat/completions with Custom Logger!") @@ -161,40 +219,38 @@ def test_chat_completion(client): def test_chat_completion_stream(client): try: # Your test data - litellm.set_verbose=False + litellm.set_verbose = False from litellm.proxy.utils import get_instance_fn + my_custom_logger = get_instance_fn( - value = "custom_callbacks.my_custom_logger", - config_file_path=python_file_path + value="custom_callbacks.my_custom_logger", config_file_path=python_file_path ) print("id of initialized custom logger", id(my_custom_logger)) litellm.callbacks = [my_custom_logger] import json + print("initialized proxy") # import the initialized custom logger print(litellm.callbacks) - print("LiteLLM Callbacks", litellm.callbacks) print("my_custom_logger", my_custom_logger) - assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call + assert ( + my_custom_logger.streaming_response_obj == None + ) # no streaming response obj is set pre call test_data = { "model": "Azure OpenAI GPT-4 Canada", "messages": [ - { - "role": "user", - "content": "write 1 line poem about LiteLLM" - }, + {"role": "user", "content": "write 1 line poem about LiteLLM"}, ], "max_tokens": 40, - "stream": True # streaming call + "stream": True, # streaming call } - response = client.post("/chat/completions", json=test_data, headers=headers) print("made request", response.status_code, response.text) complete_response = "" @@ -205,7 +261,7 @@ def test_chat_completion_stream(client): print(line) line = str(line) - json_data = line.replace('data: ', '') + json_data = line.replace("data: ", "") # Parse the JSON string data = json.loads(json_data) @@ -213,22 +269,24 @@ def test_chat_completion_stream(client): print("\n\n decode_data", data) # Access the content of 
choices[0]['message']['content'] - content = data['choices'][0]['delta']['content'] or "" + content = data["choices"][0]["delta"]["content"] or "" # Process the content as needed print("Content:", content) - complete_response+= content + complete_response += content print("\n\nHERE is the complete streaming response string", complete_response) print("\n\nHERE IS the streaming Response from callback\n\n") print(my_custom_logger.streaming_response_obj) import time + time.sleep(0.5) streamed_response = my_custom_logger.streaming_response_obj - assert complete_response == streamed_response["choices"][0]["message"]["content"] + assert ( + complete_response == streamed_response["choices"][0]["message"]["content"] + ) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") - diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index c5f99f28c..ff3b358a9 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -5,14 +5,20 @@ from dotenv import load_dotenv load_dotenv() import os, io, asyncio + sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest import litellm, openai from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + initialize, +) # Replace with the actual module where your FastAPI router is defined + @pytest.fixture def client(): @@ -23,6 +29,7 @@ def client(): app.include_router(router) # Include your router in the test app return TestClient(app) + # raise openai.AuthenticationError def test_chat_completion_exception(client): try: @@ -30,10 +37,7 @@ def test_chat_completion_exception(client): test_data = { "model": "gpt-3.5-turbo", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } @@ -42,12 +46,15 @@ def test_chat_completion_exception(client): # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) assert isinstance(openai_exception, openai.AuthenticationError) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception {str(e)}") + # raise openai.AuthenticationError def test_chat_completion_exception_azure(client): try: @@ -55,10 +62,7 @@ def test_chat_completion_exception_azure(client): test_data = { "model": "azure-gpt-3.5-turbo", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } @@ -67,7 +71,9 @@ def test_chat_completion_exception_azure(client): # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) print(openai_exception) assert isinstance(openai_exception, openai.AuthenticationError) @@ -79,17 +85,16 @@ def test_chat_completion_exception_azure(client): def test_embedding_auth_exception_azure(client): try: # Your test data - test_data = { - "model": "azure-embedding", - "input": ["hi"] - } + test_data = {"model": "azure-embedding", "input": ["hi"]} response = client.post("/embeddings", json=test_data) print("Response from proxy=", response) # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) print("Exception raised=", openai_exception) assert isinstance(openai_exception, openai.AuthenticationError) @@ -97,8 +102,6 @@ def test_embedding_auth_exception_azure(client): pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") - - # raise openai.BadRequestError # chat/completions openai def test_exception_openai_bad_model(client): @@ -107,10 +110,7 @@ def test_exception_openai_bad_model(client): test_data = { "model": "azure/GPT-12", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } @@ -119,13 +119,16 @@ def test_exception_openai_bad_model(client): # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) print("Type of exception=", type(openai_exception)) assert isinstance(openai_exception, openai.NotFoundError) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") + # chat/completions any model def test_chat_completion_exception_any_model(client): try: @@ -133,10 +136,7 @@ def test_chat_completion_exception_any_model(client): test_data = { "model": "Lite-GPT-12", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } @@ -145,7 +145,9 @@ def test_chat_completion_exception_any_model(client): # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) print("Exception raised=", openai_exception) assert isinstance(openai_exception, openai.NotFoundError) @@ -153,26 +155,22 @@ def test_chat_completion_exception_any_model(client): pytest.fail(f"LiteLLM Proxy test failed. 
Exception {str(e)}") - # embeddings any model def test_embedding_exception_any_model(client): try: # Your test data - test_data = { - "model": "Lite-GPT-12", - "input": ["hi"] - } + test_data = {"model": "Lite-GPT-12", "input": ["hi"]} response = client.post("/embeddings", json=test_data) print("Response from proxy=", response) # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") - openai_exception = openai_client._make_status_error_from_response(response=response) + openai_exception = openai_client._make_status_error_from_response( + response=response + ) print("Exception raised=", openai_exception) assert isinstance(openai_exception, openai.NotFoundError) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") - - diff --git a/litellm/tests/test_proxy_gunicorn.py b/litellm/tests/test_proxy_gunicorn.py index 9afd424d4..73e368d35 100644 --- a/litellm/tests/test_proxy_gunicorn.py +++ b/litellm/tests/test_proxy_gunicorn.py @@ -1,6 +1,6 @@ # #### What this tests #### # # Allow the user to easily run the local proxy server with Gunicorn -# # LOCAL TESTING ONLY +# # LOCAL TESTING ONLY # import sys, os, subprocess # import traceback # from dotenv import load_dotenv @@ -12,15 +12,15 @@ # sys.path.insert( # 0, os.path.abspath("../..") -# ) # Adds the parent directory to the system path +# ) # Adds the parent directory to the system path # import pytest -# import litellm +# import litellm # ### LOCAL Proxy Server INIT ### # from litellm.proxy.proxy_server import save_worker_config # Replace with the actual module where your FastAPI router is defined # filepath = os.path.dirname(os.path.abspath(__file__)) # config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" -# def get_openai_info(): +# def get_openai_info(): # return { # "api_key": os.getenv("AZURE_API_KEY"), # "api_base": os.getenv("AZURE_API_BASE"), @@ -41,7 +41,7 @@ # os.environ["AZURE_API_BASE"] = azure_info['api_base'] # os.environ["AZURE_API_VERSION"] = "2023-09-01-preview" -# ### SAVE CONFIG ### +# ### SAVE CONFIG ### # os.environ["WORKER_CONFIG"] = config_fp @@ -58,4 +58,4 @@ # subprocess.run(cmd) # This line actually starts Gunicorn # if __name__ == "__main__": -# run_server() \ No newline at end of file +# run_server() diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index bdd232f0b..22394e01b 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -9,11 +9,12 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest, logging import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError + # Configure logging logging.basicConfig( level=logging.DEBUG, # Set the desired logging level @@ -23,19 +24,23 @@ logging.basicConfig( # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + initialize, +) # Replace with the actual module where your FastAPI router is defined # Your bearer token token = "" -headers = { - "Authorization": f"Bearer {token}" -} - +headers = {"Authorization": f"Bearer {token}"} + + 
@pytest.fixture(scope="function") def client_no_auth(): # Assuming litellm.proxy.proxy_server is an object from litellm.proxy.proxy_server import cleanup_router_config_variables + cleanup_router_config_variables() filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" @@ -46,6 +51,7 @@ def client_no_auth(): return TestClient(app) + def test_chat_completion(client_no_auth): global headers try: @@ -53,14 +59,11 @@ def test_chat_completion(client_no_auth): test_data = { "model": "gpt-3.5-turbo", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } - + print("testing proxy server with chat completions") response = client_no_auth.post("/v1/chat/completions", json=test_data) print(f"response - {response.text}") @@ -70,41 +73,41 @@ def test_chat_completion(client_no_auth): except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + # Run the test -def test_chat_completion_azure(client_no_auth): +def test_chat_completion_azure(client_no_auth): global headers try: # Your test data test_data = { "model": "azure/chatgpt-v-2", "messages": [ - { - "role": "user", - "content": "write 1 sentence poem" - }, + {"role": "user", "content": "write 1 sentence poem"}, ], "max_tokens": 10, } - + print("testing proxy server with Azure Request /chat/completions") response = client_no_auth.post("/v1/chat/completions", json=test_data) assert response.status_code == 200 result = response.json() print(f"Received response: {result}") - assert len(result["choices"][0]["message"]["content"]) > 0 + assert len(result["choices"][0]["message"]["content"]) > 0 except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + # Run the test # test_chat_completion_azure() -### EMBEDDING + +### EMBEDDING def test_embedding(client_no_auth): global headers - from litellm.proxy.proxy_server import user_custom_auth + from litellm.proxy.proxy_server import user_custom_auth try: test_data = { @@ -117,13 +120,14 @@ def test_embedding(client_no_auth): assert response.status_code == 200 result = response.json() print(len(result["data"][0]["embedding"])) - assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so + assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + def test_bedrock_embedding(client_no_auth): global headers - from litellm.proxy.proxy_server import user_custom_auth + from litellm.proxy.proxy_server import user_custom_auth try: test_data = { @@ -136,13 +140,14 @@ def test_bedrock_embedding(client_no_auth): assert response.status_code == 200 result = response.json() print(len(result["data"][0]["embedding"])) - assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so + assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception - {str(e)}") + def test_sagemaker_embedding(client_no_auth): global headers - from litellm.proxy.proxy_server import user_custom_auth + from litellm.proxy.proxy_server import user_custom_auth try: test_data = { @@ -155,24 +160,26 @@ def test_sagemaker_embedding(client_no_auth): assert response.status_code == 200 result = response.json() print(len(result["data"][0]["embedding"])) - assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so + assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + # Run the test # test_embedding() #### IMAGE GENERATION - + + def test_img_gen(client_no_auth): global headers - from litellm.proxy.proxy_server import user_custom_auth + from litellm.proxy.proxy_server import user_custom_auth try: test_data = { "model": "dall-e-3", "prompt": "A cute baby sea otter", "n": 1, - "size": "1024x1024" + "size": "1024x1024", } response = client_no_auth.post("/v1/images/generations", json=test_data) @@ -180,41 +187,41 @@ def test_img_gen(client_no_auth): assert response.status_code == 200 result = response.json() print(len(result["data"][0]["url"])) - assert len(result["data"][0]["url"]) > 10 + assert len(result["data"][0]["url"]) > 10 except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") -#### ADDITIONAL + +#### ADDITIONAL # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci") def test_add_new_model(client_no_auth): global headers - try: + try: test_data = { "model_name": "test_openai_models", "litellm_params": { - "model": "gpt-3.5-turbo", + "model": "gpt-3.5-turbo", }, - "model_info": { - "description": "this is a test openai model" - } + "model_info": {"description": "this is a test openai model"}, } client_no_auth.post("/model/new", json=test_data, headers=headers) response = client_no_auth.get("/model/info", headers=headers) assert response.status_code == 200 - result = response.json() + result = response.json() print(f"response: {result}") model_info = None for m in result["data"]: if m["model_name"] == "test_openai_models": model_info = m["model_info"] assert model_info["description"] == "this is a test openai model" - except Exception as e: + except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") def test_health(client_no_auth): global headers import time + try: response = client_no_auth.get("/health") assert response.status_code == 200 @@ -222,19 +229,24 @@ def test_health(client_no_auth): assert result["unhealthy_count"] == 0 except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception - {str(e)}") + + # test_add_new_model() from litellm.integrations.custom_logger import CustomLogger + + class MyCustomHandler(CustomLogger): - def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") - def log_success_event(self, kwargs, response_obj, start_time, end_time): + def log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Success") assert kwargs["user"] == "proxy-user" assert kwargs["model"] == "gpt-3.5-turbo" assert kwargs["max_tokens"] == 10 + customHandler = MyCustomHandler() @@ -243,19 +255,16 @@ def test_chat_completion_optional_params(client_no_auth): # This tests if all the /chat/completion params are passed to litellm try: # Your test data - litellm.set_verbose=True + litellm.set_verbose = True test_data = { "model": "gpt-3.5-turbo", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, - "user": "proxy-user" + "user": "proxy-user", } - + litellm.callbacks = [customHandler] print("testing proxy server: optional params") response = client_no_auth.post("/v1/chat/completions", json=test_data) @@ -265,28 +274,39 @@ def test_chat_completion_optional_params(client_no_auth): except Exception as e: pytest.fail("LiteLLM Proxy test failed. Exception", e) + # Run the test # test_chat_completion_optional_params() -# Test Reading config.yaml file +# Test Reading config.yaml file from litellm.proxy.proxy_server import load_router_config + def test_load_router_config(): try: print("testing reading config") # this is a basic config.yaml with only a model filepath = os.path.dirname(os.path.abspath(__file__)) - result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml") + result = load_router_config( + router=None, + config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml", + ) print(result) assert len(result[1]) == 1 # this is a load balancing config yaml - result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml") + result = load_router_config( + router=None, + config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml", + ) print(result) assert len(result[1]) == 2 # config with general settings - custom callbacks - result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml") + result = load_router_config( + router=None, + config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml", + ) print(result) assert len(result[1]) == 2 @@ -295,24 +315,38 @@ def test_load_router_config(): litellm.cache = None load_router_config( router=None, - config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml" + config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml", ) assert litellm.cache is not None - assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy - assert litellm.cache.supported_call_types == ['completion', 'acompletion', 'embedding', 'aembedding'] # init with all call types - + assert "redis_client" in vars( + litellm.cache.cache + ) # it should default to redis on proxy + assert litellm.cache.supported_call_types == [ + "completion", + "acompletion", + "embedding", + "aembedding", + ] # init with all call types + print("testing reading proxy config for cache with params") load_router_config( router=None, - config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml" 
+ config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml", ) assert litellm.cache is not None print(litellm.cache) print(litellm.cache.supported_call_types) print(vars(litellm.cache.cache)) - assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy - assert litellm.cache.supported_call_types == ['embedding', 'aembedding'] # init with all call types + assert "redis_client" in vars( + litellm.cache.cache + ) # it should default to redis on proxy + assert litellm.cache.supported_call_types == [ + "embedding", + "aembedding", + ] # init with all call types except Exception as e: pytest.fail("Proxy: Got exception reading config", e) -# test_load_router_config() \ No newline at end of file + + +# test_load_router_config() diff --git a/litellm/tests/test_proxy_server_caching.py b/litellm/tests/test_proxy_server_caching.py index 75b017e50..f37cd9b58 100644 --- a/litellm/tests/test_proxy_server_caching.py +++ b/litellm/tests/test_proxy_server_caching.py @@ -1,5 +1,5 @@ # #### What this tests #### -# # This tests using caching w/ litellm which requires SSL=True +# # This tests using caching w/ litellm which requires SSL=True # import sys, os # import time @@ -35,4 +35,4 @@ # print(f"error occurred: {traceback.format_exc()}") # pytest.fail(f"Error occurred: {e}") -# test_caching_v2() \ No newline at end of file +# test_caching_v2() diff --git a/litellm/tests/test_proxy_server_cost.py b/litellm/tests/test_proxy_server_cost.py index b127e72e3..f6cf11ada 100644 --- a/litellm/tests/test_proxy_server_cost.py +++ b/litellm/tests/test_proxy_server_cost.py @@ -30,22 +30,22 @@ # yield client # @pytest.mark.asyncio -# async def test_proxy_cost_tracking(client): +# async def test_proxy_cost_tracking(client): # """ -# Get min cost. +# Get min cost. # Create new key. -# Run 10 parallel calls. -# Check cost for key at the end. -# assert it's > min cost. +# Run 10 parallel calls. +# Check cost for key at the end. +# assert it's > min cost. 
# """ # model = "gpt-3.5-turbo" # messages = [{"role": "user", "content": "Hey, how's it going?"}] # number_of_calls = 1 # min_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls -# try: +# try: # ### CREATE NEW KEY ### # test_data = { -# "models": ["azure-model"], +# "models": ["azure-model"], # } # # Your bearer token # token = os.getenv("PROXY_MASTER_KEY") @@ -57,7 +57,7 @@ # key = create_new_key.json()["key"] # print(f"received key: {key}") # ### MAKE PARALLEL CALLS ### -# async def test_chat_completions(): +# async def test_chat_completions(): # # Your test data # test_data = { # "model": "azure-model", diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index 14b239ae1..62bdfeb69 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -9,63 +9,104 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest, logging import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError + # Configure logging logging.basicConfig( level=logging.DEBUG, # Set the desired logging level format="%(asctime)s - %(levelname)s - %(message)s", ) from concurrent.futures import ThreadPoolExecutor + # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + startup_event, +) # Replace with the actual module where your FastAPI router is defined + filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config.yaml" -save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +save_worker_config( + config=config_fp, + model=None, + alias=None, + api_base=None, + api_version=None, + debug=False, + temperature=None, + max_tokens=None, + request_timeout=600, + max_budget=None, + telemetry=False, + drop_params=True, + add_function_to_prompt=False, + headers=None, + save=False, + use_queue=False, +) app = FastAPI() app.include_router(router) # Include your router in the test app + + @app.on_event("startup") async def wrapper_startup_event(): await startup_event() + # Here you create a fixture that will be used by your tests # Make sure the fixture returns TestClient(app) @pytest.fixture(autouse=True) def client(): from litellm.proxy.proxy_server import cleanup_router_config_variables + cleanup_router_config_variables() with TestClient(app) as client: yield client + def test_add_new_key(client): try: # Your test data test_data = { - "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"], - "aliases": {"mistral-7b": "gpt-3.5-turbo"}, - "duration": "20m" + "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": "20m", } print("testing proxy server") # Your bearer token token = os.getenv("PROXY_MASTER_KEY") - headers = { - "Authorization": f"Bearer {token}" - } + headers = {"Authorization": f"Bearer {token}"} response 
= client.post("/key/generate", json=test_data, headers=headers) print(f"response: {response.text}") assert response.status_code == 200 result = response.json() assert result["key"].startswith("sk-") + def _post_data(): - json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]} - response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) + json_data = { + "model": "azure-model", + "messages": [ + { + "role": "user", + "content": f"this is a test request, write a short poem {time.time()}", + } + ], + } + response = client.post( + "/chat/completions", + json=json_data, + headers={"Authorization": f"Bearer {result['key']}"}, + ) return response + _post_data() print(f"Received response: {result}") except Exception as e: @@ -76,33 +117,34 @@ def test_update_new_key(client): try: # Your test data test_data = { - "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"], - "aliases": {"mistral-7b": "gpt-3.5-turbo"}, - "duration": "20m" + "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": "20m", } print("testing proxy server") # Your bearer token token = os.getenv("PROXY_MASTER_KEY") - headers = { - "Authorization": f"Bearer {token}" - } + headers = {"Authorization": f"Bearer {token}"} response = client.post("/key/generate", json=test_data, headers=headers) print(f"response: {response.text}") assert response.status_code == 200 result = response.json() assert result["key"].startswith("sk-") + def _post_data(): - json_data = {'models': ['bedrock-models'], "key": result["key"]} + json_data = {"models": ["bedrock-models"], "key": result["key"]} response = client.post("/key/update", json=json_data, headers=headers) print(f"response text: {response.text}") assert response.status_code == 200 return response + _post_data() print(f"Received response: {result}") except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") + # # Run the test - only runs via pytest @@ -113,17 +155,29 @@ def test_add_new_key_max_parallel_limit(client): # Your bearer token token = os.getenv("PROXY_MASTER_KEY") - headers = { - "Authorization": f"Bearer {token}" - } + headers = {"Authorization": f"Bearer {token}"} response = client.post("/key/generate", json=test_data, headers=headers) print(f"response: {response.text}") assert response.status_code == 200 result = response.json() + def _post_data(): - json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]} - response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) + json_data = { + "model": "azure-model", + "messages": [ + { + "role": "user", + "content": f"this is a test request, write a short poem {time.time()}", + } + ], + } + response = client.post( + "/chat/completions", + json=json_data, + headers={"Authorization": f"Bearer {result['key']}"}, + ) return response + def _run_in_parallel(): with ThreadPoolExecutor(max_workers=2) as executor: future1 = executor.submit(_post_data) @@ -134,30 +188,45 @@ def test_add_new_key_max_parallel_limit(client): response2 = future2.result() if response1.status_code == 429 or response2.status_code == 429: pass - else: + else: raise Exception() + _run_in_parallel() except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception: {str(e)}") + def test_add_new_key_max_parallel_limit_streaming(client): try: # Your test data test_data = {"duration": "20m", "max_parallel_requests": 1} # Your bearer token - token = os.getenv('PROXY_MASTER_KEY') + token = os.getenv("PROXY_MASTER_KEY") - headers = { - "Authorization": f"Bearer {token}" - } + headers = {"Authorization": f"Bearer {token}"} response = client.post("/key/generate", json=test_data, headers=headers) print(f"response: {response.text}") assert response.status_code == 200 result = response.json() + def _post_data(): - json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}], "stream": True} - response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) + json_data = { + "model": "azure-model", + "messages": [ + { + "role": "user", + "content": f"this is a test request, write a short poem {time.time()}", + } + ], + "stream": True, + } + response = client.post( + "/chat/completions", + json=json_data, + headers={"Authorization": f"Bearer {result['key']}"}, + ) return response + def _run_in_parallel(): with ThreadPoolExecutor(max_workers=2) as executor: future1 = executor.submit(_post_data) @@ -168,8 +237,9 @@ def test_add_new_key_max_parallel_limit_streaming(client): response2 = future2.result() if response1.status_code == 429 or response2.status_code == 429: pass - else: + else: raise Exception() + _run_in_parallel() except Exception as e: - pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") \ No newline at end of file + pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") diff --git a/litellm/tests/test_proxy_server_langfuse.py b/litellm/tests/test_proxy_server_langfuse.py index 0e21a8eb9..4f896f792 100644 --- a/litellm/tests/test_proxy_server_langfuse.py +++ b/litellm/tests/test_proxy_server_langfuse.py @@ -9,11 +9,12 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") -) # Adds the parent directory to the system path +) # Adds the parent directory to the system path import pytest, logging import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError + # Configure logging logging.basicConfig( level=logging.DEBUG, # Set the desired logging level @@ -23,16 +24,41 @@ logging.basicConfig( # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import ( + router, + save_worker_config, + startup_event, +) # Replace with the actual module where your FastAPI router is defined + filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config.yaml" -save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +save_worker_config( + config=config_fp, + model=None, + alias=None, + api_base=None, + api_version=None, + debug=False, + temperature=None, + max_tokens=None, + request_timeout=600, + max_budget=None, + telemetry=False, + drop_params=True, + add_function_to_prompt=False, + headers=None, + save=False, + use_queue=False, +) app = 
FastAPI() app.include_router(router) # Include your router in the test app + + @app.on_event("startup") async def wrapper_startup_event(): await startup_event() + # Here you create a fixture that will be used by your tests # Make sure the fixture returns TestClient(app) @pytest.fixture(autouse=True) @@ -40,16 +66,14 @@ def client(): with TestClient(app) as client: yield client + def test_chat_completion(client): try: # Your test data test_data = { "model": "gpt-3.5-turbo", "messages": [ - { - "role": "user", - "content": "hi" - }, + {"role": "user", "content": "hi"}, ], "max_tokens": 10, } diff --git a/litellm/tests/test_proxy_server_spend.py b/litellm/tests/test_proxy_server_spend.py index f64ad8987..9fed60412 100644 --- a/litellm/tests/test_proxy_server_spend.py +++ b/litellm/tests/test_proxy_server_spend.py @@ -79,4 +79,4 @@ # # print(json.dumps(super_fake_response.model_dump(), indent=4)) -# asyncio.run(loadtest_fn()) \ No newline at end of file +# asyncio.run(loadtest_fn()) diff --git a/litellm/tests/test_register_model.py b/litellm/tests/test_register_model.py index 185e96c20..6b1707988 100644 --- a/litellm/tests/test_register_model.py +++ b/litellm/tests/test_register_model.py @@ -4,44 +4,62 @@ import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm + def test_update_model_cost(): - try: - litellm.register_model({ - "gpt-4": { - "max_tokens": 8192, - "input_cost_per_token": 0.00002, - "output_cost_per_token": 0.00006, - "litellm_provider": "openai", - "mode": "chat" - }, - }) + try: + litellm.register_model( + { + "gpt-4": { + "max_tokens": 8192, + "input_cost_per_token": 0.00002, + "output_cost_per_token": 0.00006, + "litellm_provider": "openai", + "mode": "chat", + }, + } + ) assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00002 except Exception as e: pytest.fail(f"An error occurred: {e}") + # test_update_model_cost() -def test_update_model_cost_map_url(): - try: - litellm.register_model(model_cost="https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json") + +def test_update_model_cost_map_url(): + try: + litellm.register_model( + model_cost="https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" + ) assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00003 except Exception as e: pytest.fail(f"An error occurred: {e}") + # test_update_model_cost_map_url() -def test_update_model_cost_via_completion(): + +def test_update_model_cost_via_completion(): try: - response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], input_cost_per_token=0.3, output_cost_per_token=0.4) - print(f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}") + response = litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + input_cost_per_token=0.3, + output_cost_per_token=0.4, + ) + print( + f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}" + ) assert litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"] == 0.3 assert litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"] == 0.4 - except Exception as e: + except Exception as e: pytest.fail(f"An error occurred: {e}") -test_update_model_cost_via_completion() \ No newline at end of file + +test_update_model_cost_via_completion() diff --git a/litellm/tests/test_router.py 
b/litellm/tests/test_router.py index 81440c257..3b8ea7ed4 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -1,9 +1,10 @@ #### What this tests #### -#This tests litellm router +# This tests litellm router import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -12,371 +13,358 @@ from litellm import Router from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from dotenv import load_dotenv + load_dotenv() + def test_exception_raising(): - # this tests if the router raises an exception when invalid params are set - # in this test both deployments have bad keys - Keep this test. It validates if the router raises the most recent exception - litellm.set_verbose=True - import openai - try: - print("testing if router raises an exception") - old_api_key = os.environ["AZURE_API_KEY"] - os.environ["AZURE_API_KEY"] = "" - model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # - "model": "gpt-3.5-turbo", - "api_key": "bad-key", - }, - "tpm": 240000, - "rpm": 1800 - } - ] - router = Router(model_list=model_list, - redis_host=os.getenv("REDIS_HOST"), - redis_password=os.getenv("REDIS_PASSWORD"), - redis_port=int(os.getenv("REDIS_PORT")), - routing_strategy="simple-shuffle", - set_verbose=False, - num_retries=1) # type: ignore - response = router.completion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will fail" - } - ] - ) - os.environ["AZURE_API_KEY"] = old_api_key - pytest.fail(f"Should have raised an Auth Error") - except openai.AuthenticationError: - print("Test Passed: Caught an OPENAI AUTH Error, Good job. This is what we needed!") - os.environ["AZURE_API_KEY"] = old_api_key - router.reset() - except Exception as e: - os.environ["AZURE_API_KEY"] = old_api_key - print("Got unexpected exception on router!", e) + # this tests if the router raises an exception when invalid params are set + # in this test both deployments have bad keys - Keep this test. 
It validates if the router raises the most recent exception + litellm.set_verbose = True + import openai + + try: + print("testing if router raises an exception") + old_api_key = os.environ["AZURE_API_KEY"] + os.environ["AZURE_API_KEY"] = "" + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # + "model": "gpt-3.5-turbo", + "api_key": "bad-key", + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=int(os.getenv("REDIS_PORT")), + routing_strategy="simple-shuffle", + set_verbose=False, + num_retries=1, + ) # type: ignore + response = router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hello this request will fail"}], + ) + os.environ["AZURE_API_KEY"] = old_api_key + pytest.fail(f"Should have raised an Auth Error") + except openai.AuthenticationError: + print( + "Test Passed: Caught an OPENAI AUTH Error, Good job. This is what we needed!" + ) + os.environ["AZURE_API_KEY"] = old_api_key + router.reset() + except Exception as e: + os.environ["AZURE_API_KEY"] = old_api_key + print("Got unexpected exception on router!", e) + + # test_exception_raising() def test_reading_key_from_model_list(): - # [PROD TEST CASE] - # this tests if the router can read key from model list and make completion call, and completion + stream call. This is 90% of the router use case - # DO NOT REMOVE THIS TEST. It's an IMP ONE. Speak to Ishaan, if you are tring to remove this - litellm.set_verbose=False - import openai - try: - print("testing if router raises an exception") - old_api_key = os.environ["AZURE_API_KEY"] - os.environ.pop("AZURE_API_KEY", None) - model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": old_api_key, - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - } - ] + # [PROD TEST CASE] + # this tests if the router can read key from model list and make completion call, and completion + stream call. This is 90% of the router use case + # DO NOT REMOVE THIS TEST. It's an IMP ONE. 
Speak to Ishaan, if you are tring to remove this + litellm.set_verbose = False + import openai + + try: + print("testing if router raises an exception") + old_api_key = os.environ["AZURE_API_KEY"] + os.environ.pop("AZURE_API_KEY", None) + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + } + ] + + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=int(os.getenv("REDIS_PORT")), + routing_strategy="simple-shuffle", + set_verbose=True, + num_retries=1, + ) # type: ignore + response = router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hello this request will fail"}], + ) + print("\n response", response) + str_response = response.choices[0].message.content + print("\n str_response", str_response) + assert len(str_response) > 0 + + print("\n Testing streaming response") + response = router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hello this request will fail"}], + stream=True, + ) + completed_response = "" + for chunk in response: + if chunk is not None: + print(chunk) + completed_response += chunk.choices[0].delta.content or "" + print("\n completed_response", completed_response) + assert len(completed_response) > 0 + print("\n Passed Streaming") + os.environ["AZURE_API_KEY"] = old_api_key + router.reset() + except Exception as e: + os.environ["AZURE_API_KEY"] = old_api_key + print(f"FAILED TEST") + pytest.fail(f"Got unexpected exception on router! - {e}") - router = Router(model_list=model_list, - redis_host=os.getenv("REDIS_HOST"), - redis_password=os.getenv("REDIS_PASSWORD"), - redis_port=int(os.getenv("REDIS_PORT")), - routing_strategy="simple-shuffle", - set_verbose=True, - num_retries=1) # type: ignore - response = router.completion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will fail" - } - ] - ) - print("\n response", response) - str_response = response.choices[0].message.content - print("\n str_response", str_response) - assert len(str_response) > 0 - print("\n Testing streaming response") - response = router.completion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will fail" - } - ], - stream=True - ) - completed_response = "" - for chunk in response: - if chunk is not None: - print(chunk) - completed_response += chunk.choices[0].delta.content or "" - print("\n completed_response", completed_response) - assert len(completed_response) > 0 - print("\n Passed Streaming") - os.environ["AZURE_API_KEY"] = old_api_key - router.reset() - except Exception as e: - os.environ["AZURE_API_KEY"] = old_api_key - print(f"FAILED TEST") - pytest.fail(f"Got unexpected exception on router! 
- {e}") # test_reading_key_from_model_list() + def test_call_one_endpoint(): - # [PROD TEST CASE] - # user passes one deployment they want to call on the router, we call the specified one - # this test makes a completion calls azure/chatgpt-v-2, it should work - try: - print("Testing calling a specific deployment") - old_api_key = os.environ["AZURE_API_KEY"] + # [PROD TEST CASE] + # user passes one deployment they want to call on the router, we call the specified one + # this test makes a completion calls azure/chatgpt-v-2, it should work + try: + print("Testing calling a specific deployment") + old_api_key = os.environ["AZURE_API_KEY"] - model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": old_api_key, - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "claude-v1", - "litellm_params": { - "model": "bedrock/anthropic.claude-instant-v1", - }, - "tpm": 100000, - "rpm": 10000, - }, - { - "model_name": "text-embedding-ada-002", - "litellm_params": { - "model": "azure/azure-embedding-model", - "api_key":os.environ['AZURE_API_KEY'], - "api_base": os.environ['AZURE_API_BASE'] - }, - "tpm": 100000, - "rpm": 10000, - }, - ] - litellm.set_verbose=True - router = Router(model_list=model_list, - routing_strategy="simple-shuffle", - set_verbose=True, - num_retries=1) # type: ignore - old_api_base = os.environ.pop("AZURE_API_BASE", None) + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "claude-v1", + "litellm_params": { + "model": "bedrock/anthropic.claude-instant-v1", + }, + "tpm": 100000, + "rpm": 10000, + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "azure/azure-embedding-model", + "api_key": os.environ["AZURE_API_KEY"], + "api_base": os.environ["AZURE_API_BASE"], + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + litellm.set_verbose = True + router = Router( + model_list=model_list, + routing_strategy="simple-shuffle", + set_verbose=True, + num_retries=1, + ) # type: ignore + old_api_base = os.environ.pop("AZURE_API_BASE", None) - async def call_azure_completion(): - response = await router.acompletion( - model="azure/chatgpt-v-2", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ], - specific_deployment=True - ) - print("\n response", response) + async def call_azure_completion(): + response = await router.acompletion( + model="azure/chatgpt-v-2", + messages=[{"role": "user", "content": "hello this request will pass"}], + specific_deployment=True, + ) + print("\n response", response) - async def call_bedrock_claude(): - response = await router.acompletion( - model="bedrock/anthropic.claude-instant-v1", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ], - specific_deployment=True - ) + async def call_bedrock_claude(): + response = await router.acompletion( + model="bedrock/anthropic.claude-instant-v1", + messages=[{"role": "user", "content": "hello this request will pass"}], + specific_deployment=True, + ) - print("\n response", response) - - async def 
call_azure_embedding(): - response = await router.aembedding( - model="azure/azure-embedding-model", - input = ["good morning from litellm"], - specific_deployment=True - ) + print("\n response", response) + + async def call_azure_embedding(): + response = await router.aembedding( + model="azure/azure-embedding-model", + input=["good morning from litellm"], + specific_deployment=True, + ) + + print("\n response", response) + + asyncio.run(call_azure_completion()) + asyncio.run(call_bedrock_claude()) + asyncio.run(call_azure_embedding()) + + os.environ["AZURE_API_BASE"] = old_api_base + os.environ["AZURE_API_KEY"] = old_api_key + except Exception as e: + print(f"FAILED TEST") + pytest.fail(f"Got unexpected exception on router! - {e}") - print("\n response", response) - asyncio.run(call_azure_completion()) - asyncio.run(call_bedrock_claude()) - asyncio.run(call_azure_embedding()) - - os.environ["AZURE_API_BASE"] = old_api_base - os.environ["AZURE_API_KEY"] = old_api_key - except Exception as e: - print(f"FAILED TEST") - pytest.fail(f"Got unexpected exception on router! - {e}") # test_call_one_endpoint() - def test_router_azure_acompletion(): - # [PROD TEST CASE] - # This is 90% of the router use case, makes an acompletion call, acompletion + stream call and verifies it got a response - # DO NOT REMOVE THIS TEST. It's an IMP ONE. Speak to Ishaan, if you are tring to remove this - litellm.set_verbose=False - import openai - try: - print("Router Test Azure - Acompletion, Acompletion with stream") + # [PROD TEST CASE] + # This is 90% of the router use case, makes an acompletion call, acompletion + stream call and verifies it got a response + # DO NOT REMOVE THIS TEST. It's an IMP ONE. Speak to Ishaan, if you are tring to remove this + litellm.set_verbose = False + import openai - # remove api key from env to repro how proxy passes key to router - old_api_key = os.environ["AZURE_API_KEY"] - os.environ.pop("AZURE_API_KEY", None) - - model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": old_api_key, - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/gpt-turbo", - "api_key": os.getenv("AZURE_FRANCE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": "https://openai-france-1234.openai.azure.com" - }, - "rpm": 1800 - } - ] + try: + print("Router Test Azure - Acompletion, Acompletion with stream") - router = Router(model_list=model_list, - routing_strategy="simple-shuffle", - set_verbose=True - ) # type: ignore - - async def test1(): + # remove api key from env to repro how proxy passes key to router + old_api_key = os.environ["AZURE_API_KEY"] + os.environ.pop("AZURE_API_KEY", None) + + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/gpt-turbo", + "api_key": os.getenv("AZURE_FRANCE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + 
"api_base": "https://openai-france-1234.openai.azure.com", + }, + "rpm": 1800, + }, + ] + + router = Router( + model_list=model_list, routing_strategy="simple-shuffle", set_verbose=True + ) # type: ignore + + async def test1(): + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hello this request will pass"}], + ) + str_response = response.choices[0].message.content + print("\n str_response", str_response) + assert len(str_response) > 0 + print("\n response", response) + + asyncio.run(test1()) + + print("\n Testing streaming response") + + async def test2(): + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hello this request will fail"}], + stream=True, + ) + completed_response = "" + async for chunk in response: + if chunk is not None: + print(chunk) + completed_response += chunk.choices[0].delta.content or "" + print("\n completed_response", completed_response) + assert len(completed_response) > 0 + + asyncio.run(test2()) + print("\n Passed Streaming") + os.environ["AZURE_API_KEY"] = old_api_key + router.reset() + except Exception as e: + os.environ["AZURE_API_KEY"] = old_api_key + print(f"FAILED TEST") + pytest.fail(f"Got unexpected exception on router! - {e}") - response = await router.acompletion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ] - ) - str_response = response.choices[0].message.content - print("\n str_response", str_response) - assert len(str_response) > 0 - print("\n response", response) - asyncio.run(test1()) - print("\n Testing streaming response") - async def test2(): - response = await router.acompletion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello this request will fail" - } - ], - stream=True - ) - completed_response = "" - async for chunk in response: - if chunk is not None: - print(chunk) - completed_response += chunk.choices[0].delta.content or "" - print("\n completed_response", completed_response) - assert len(completed_response) > 0 - asyncio.run(test2()) - print("\n Passed Streaming") - os.environ["AZURE_API_KEY"] = old_api_key - router.reset() - except Exception as e: - os.environ["AZURE_API_KEY"] = old_api_key - print(f"FAILED TEST") - pytest.fail(f"Got unexpected exception on router! - {e}") # test_router_azure_acompletion() -### FUNCTION CALLING +### FUNCTION CALLING -def test_function_calling(): - model_list = [ - { - "model_name": "gpt-3.5-turbo-0613", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 100000, - "rpm": 10000, - }, - ] - messages = [ - {"role": "user", "content": "What is the weather like in Boston?"} - ] - functions = [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } - ] +def test_function_calling(): + model_list = [ + { + "model_name": "gpt-3.5-turbo-0613", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, + }, + ] - router = Router(model_list=model_list) - response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions) - router.reset() - print(response) + messages = [{"role": "user", "content": "What is the weather like in Boston?"}] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] -# test_acompletion_on_router() + router = Router(model_list=model_list) + response = router.completion( + model="gpt-3.5-turbo-0613", messages=messages, functions=functions + ) + router.reset() + print(response) -def test_function_calling_on_router(): - try: - litellm.set_verbose = True - model_list = [ + +# test_acompletion_on_router() + + +def test_function_calling_on_router(): + try: + litellm.set_verbose = True + model_list = [ { "model_name": "gpt-3.5-turbo", "litellm_params": { @@ -385,7 +373,7 @@ def test_function_calling_on_router(): }, }, ] - function1 = [ + function1 = [ { "name": "get_current_weather", "description": "Get the current weather in a given location", @@ -402,470 +390,541 @@ def test_function_calling_on_router(): }, } ] - router = Router( - model_list=model_list, - redis_host=os.getenv("REDIS_HOST"), - redis_password=os.getenv("REDIS_PASSWORD"), - redis_port=os.getenv("REDIS_PORT") - ) - messages=[ - { - "role": "user", - "content": "what's the weather in boston" - } - ] - response = router.completion(model="gpt-3.5-turbo", messages=messages, functions=function1) - print(f"final returned response: {response}") - router.reset() - assert isinstance(response["choices"][0]["message"]["function_call"], dict) - except Exception as e: - print(f"An exception occurred: {e}") + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT"), + ) + messages = [{"role": "user", "content": "what's the weather in boston"}] + response = router.completion( + model="gpt-3.5-turbo", messages=messages, functions=function1 + ) + print(f"final returned response: {response}") + router.reset() + assert isinstance(response["choices"][0]["message"]["function_call"], dict) + except Exception as e: + print(f"An exception occurred: {e}") + # test_function_calling_on_router() -### IMAGE GENERATION + +### IMAGE GENERATION @pytest.mark.asyncio async def test_aimg_gen_on_router(): - litellm.set_verbose = True - try: - model_list = [ - { - "model_name": "dall-e-3", - "litellm_params": { - "model": "dall-e-3", - }, - }, - { - "model_name": "dall-e-3", - "litellm_params": { - "model": "azure/dall-e-3-test", - "api_version": "2023-12-01-preview", - "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), - "api_key": os.getenv("AZURE_SWEDEN_API_KEY") - } - }, - { - "model_name": "dall-e-2", - "litellm_params": { - "model": "azure/", - "api_version": "2023-06-01-preview", - "api_base": os.getenv("AZURE_API_BASE"), - 
"api_key": os.getenv("AZURE_API_KEY") - } - } - ] - router = Router(model_list=model_list) - response = await router.aimage_generation( - model="dall-e-3", - prompt="A cute baby sea otter" - ) - print(response) - assert len(response.data) > 0 + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY"), + }, + }, + { + "model_name": "dall-e-2", + "litellm_params": { + "model": "azure/", + "api_version": "2023-06-01-preview", + "api_base": os.getenv("AZURE_API_BASE"), + "api_key": os.getenv("AZURE_API_KEY"), + }, + }, + ] + router = Router(model_list=model_list) + response = await router.aimage_generation( + model="dall-e-3", prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 + + response = await router.aimage_generation( + model="dall-e-2", prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 + + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - response = await router.aimage_generation( - model="dall-e-2", - prompt="A cute baby sea otter" - ) - print(response) - assert len(response.data) > 0 - - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # asyncio.run(test_aimg_gen_on_router()) + def test_img_gen_on_router(): - litellm.set_verbose = True - try: - model_list = [ - { - "model_name": "dall-e-3", - "litellm_params": { - "model": "dall-e-3", - }, - }, - { - "model_name": "dall-e-3", - "litellm_params": { - "model": "azure/dall-e-3-test", - "api_version": "2023-12-01-preview", - "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), - "api_key": os.getenv("AZURE_SWEDEN_API_KEY") - } - } - ] - router = Router(model_list=model_list) - response = router.image_generation( - model="dall-e-3", - prompt="A cute baby sea otter" - ) - print(response) - assert len(response.data) > 0 - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY"), + }, + }, + ] + router = Router(model_list=model_list) + response = router.image_generation( + model="dall-e-3", prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + # test_img_gen_on_router() ### -def test_aembedding_on_router(): - litellm.set_verbose = True - try: - model_list = [ - { - "model_name": "text-embedding-ada-002", - "litellm_params": { - "model": "text-embedding-ada-002", - }, - "tpm": 100000, - "rpm": 10000, - }, - ] - router = Router(model_list=model_list) - async def embedding_call(): - response = await router.aembedding( - model="text-embedding-ada-002", - input=["good morning from litellm", "this is another item"], - ) - print(response) - asyncio.run(embedding_call()) - print("\n Making sync Embedding call\n") - response = router.embedding( - 
model="text-embedding-ada-002", - input=["good morning from litellm 2"], - ) - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") +def test_aembedding_on_router(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "text-embedding-ada-002", + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + router = Router(model_list=model_list) + + async def embedding_call(): + response = await router.aembedding( + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"], + ) + print(response) + + asyncio.run(embedding_call()) + + print("\n Making sync Embedding call\n") + response = router.embedding( + model="text-embedding-ada-002", + input=["good morning from litellm 2"], + ) + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + + # test_aembedding_on_router() def test_azure_embedding_on_router(): - """ - [PROD Use Case] - Makes an aembedding call + embedding call - """ - litellm.set_verbose = True - try: - model_list = [ - { - "model_name": "text-embedding-ada-002", - "litellm_params": { - "model": "azure/azure-embedding-model", - "api_key":os.environ['AZURE_API_KEY'], - "api_base": os.environ['AZURE_API_BASE'] - }, - "tpm": 100000, - "rpm": 10000, - }, - ] - router = Router(model_list=model_list) + """ + [PROD Use Case] - Makes an aembedding call + embedding call + """ + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "azure/azure-embedding-model", + "api_key": os.environ["AZURE_API_KEY"], + "api_base": os.environ["AZURE_API_BASE"], + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + router = Router(model_list=model_list) - async def embedding_call(): - response = await router.aembedding( - model="text-embedding-ada-002", - input=["good morning from litellm"] - ) - print(response) - asyncio.run(embedding_call()) + async def embedding_call(): + response = await router.aembedding( + model="text-embedding-ada-002", input=["good morning from litellm"] + ) + print(response) + + asyncio.run(embedding_call()) + + print("\n Making sync Azure Embedding call\n") + + response = router.embedding( + model="text-embedding-ada-002", + input=["test 2 from litellm. async embedding"], + ) + print(response) + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - print("\n Making sync Azure Embedding call\n") - response = router.embedding( - model="text-embedding-ada-002", - input=["test 2 from litellm. 
async embedding"] - ) - print(response) - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_azure_embedding_on_router() def test_bedrock_on_router(): - litellm.set_verbose = True - print("\n Testing bedrock on router\n") - try: - model_list = [ - { - "model_name": "claude-v1", - "litellm_params": { - "model": "bedrock/anthropic.claude-instant-v1", - }, - "tpm": 100000, - "rpm": 10000, - }, - ] + litellm.set_verbose = True + print("\n Testing bedrock on router\n") + try: + model_list = [ + { + "model_name": "claude-v1", + "litellm_params": { + "model": "bedrock/anthropic.claude-instant-v1", + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + async def test(): + router = Router(model_list=model_list) + response = await router.acompletion( + model="claude-v1", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ], + ) + print(response) + router.reset() + + asyncio.run(test()) + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + - async def test(): - router = Router(model_list=model_list) - response = await router.acompletion( - model="claude-v1", - messages=[ - { - "role": "user", - "content": "hello from litellm test", - } - ] - ) - print(response) - router.reset() - asyncio.run(test()) - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_bedrock_on_router() + # test openai-compatible endpoint @pytest.mark.asyncio async def test_mistral_on_router(): - litellm.set_verbose = True - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "mistral/mistral-medium", - }, - }, - ] - router = Router(model_list=model_list) - response = await router.acompletion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello from litellm test", - } - ] - ) - print(response) + litellm.set_verbose = True + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "mistral/mistral-medium", + }, + }, + ] + router = Router(model_list=model_list) + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ], + ) + print(response) + + # asyncio.run(test_mistral_on_router()) + def test_openai_completion_on_router(): - # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream - # 4 LLM API calls made here. If it fails, add retries. Do not remove this test. - litellm.set_verbose = True - print("\n Testing OpenAI on router\n") - try: - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - }, - }, - ] - router = Router(model_list=model_list) + # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream + # 4 LLM API calls made here. If it fails, add retries. Do not remove this test. 
+ litellm.set_verbose = True + print("\n Testing OpenAI on router\n") + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + ] + router = Router(model_list=model_list) - async def test(): - response = await router.acompletion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello from litellm test", - } - ] - ) - print(response) - assert len(response.choices[0].message.content) > 0 + async def test(): + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ], + ) + print(response) + assert len(response.choices[0].message.content) > 0 + + print("\n streaming + acompletion test") + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"hello from litellm test {time.time()}", + } + ], + stream=True, + ) + complete_response = "" + print(response) + # if you want to see all the attributes and methods + async for chunk in response: + print(chunk) + complete_response += chunk.choices[0].delta.content or "" + print("\n complete response: ", complete_response) + assert len(complete_response) > 0 + + asyncio.run(test()) + print("\n Testing Sync completion calls \n") + response = router.completion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test2", + } + ], + ) + print(response) + assert len(response.choices[0].message.content) > 0 + + print("\n streaming + completion test") + response = router.completion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test3", + } + ], + stream=True, + ) + complete_response = "" + print(response) + for chunk in response: + print(chunk) + complete_response += chunk.choices[0].delta.content or "" + print("\n complete response: ", complete_response) + assert len(complete_response) > 0 + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - print("\n streaming + acompletion test") - response = await router.acompletion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": f"hello from litellm test {time.time()}", - } - ], - stream=True - ) - complete_response = "" - print(response) - # if you want to see all the attributes and methods - async for chunk in response: - print(chunk) - complete_response += chunk.choices[0].delta.content or "" - print("\n complete response: ", complete_response) - assert len(complete_response) > 0 - - asyncio.run(test()) - print("\n Testing Sync completion calls \n") - response = router.completion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello from litellm test2", - } - ] - ) - print(response) - assert len(response.choices[0].message.content) > 0 - print("\n streaming + completion test") - response = router.completion( - model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello from litellm test3", - } - ], - stream=True - ) - complete_response = "" - print(response) - for chunk in response: - print(chunk) - complete_response += chunk.choices[0].delta.content or "" - print("\n complete response: ", complete_response) - assert len(complete_response) > 0 - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_openai_completion_on_router() def test_reading_keys_os_environ(): - import openai - try: - model_list = [ - { - "model_name": 
"gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - "api_key": "os.environ/AZURE_API_KEY", - "api_base": "os.environ/AZURE_API_BASE", - "api_version": "os.environ/AZURE_API_VERSION", - "timeout": "os.environ/AZURE_TIMEOUT", - "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", - "max_retries": "os.environ/AZURE_MAX_RETRIES", - }, - }, - ] - - router = Router(model_list=model_list) - for model in router.model_list: - assert model["litellm_params"]["api_key"] == os.environ["AZURE_API_KEY"], f"{model['litellm_params']['api_key']} vs {os.environ['AZURE_API_KEY']}" - assert model["litellm_params"]["api_base"] == os.environ["AZURE_API_BASE"], f"{model['litellm_params']['api_base']} vs {os.environ['AZURE_API_BASE']}" - assert model["litellm_params"]["api_version"] == os.environ["AZURE_API_VERSION"], f"{model['litellm_params']['api_version']} vs {os.environ['AZURE_API_VERSION']}" - assert float(model["litellm_params"]["timeout"]) == float(os.environ["AZURE_TIMEOUT"]), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}" - assert float(model["litellm_params"]["stream_timeout"]) == float(os.environ["AZURE_STREAM_TIMEOUT"]), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}" - assert int(model["litellm_params"]["max_retries"]) == int(os.environ["AZURE_MAX_RETRIES"]), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}" - print("passed testing of reading keys from os.environ") - async_client: openai.AsyncAzureOpenAI = model["async_client"] # type: ignore - assert async_client.api_key == os.environ["AZURE_API_KEY"] - assert async_client.base_url == os.environ["AZURE_API_BASE"] - assert async_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert async_client.timeout == (os.environ["AZURE_TIMEOUT"]), f"{async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("async client set correctly!") + import openai - print("\n Testing async streaming client") + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "os.environ/AZURE_API_KEY", + "api_base": "os.environ/AZURE_API_BASE", + "api_version": "os.environ/AZURE_API_VERSION", + "timeout": "os.environ/AZURE_TIMEOUT", + "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", + "max_retries": "os.environ/AZURE_MAX_RETRIES", + }, + }, + ] - stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"] # type: ignore - assert stream_async_client.api_key == os.environ["AZURE_API_KEY"] - assert stream_async_client.base_url == os.environ["AZURE_API_BASE"] - assert stream_async_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert stream_async_client.timeout == (os.environ["AZURE_STREAM_TIMEOUT"]), f"{stream_async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("async stream client set correctly!") + router = Router(model_list=model_list) + for model in router.model_list: + assert ( + model["litellm_params"]["api_key"] == os.environ["AZURE_API_KEY"] + ), f"{model['litellm_params']['api_key']} vs {os.environ['AZURE_API_KEY']}" + assert ( + model["litellm_params"]["api_base"] == os.environ["AZURE_API_BASE"] + ), f"{model['litellm_params']['api_base']} vs {os.environ['AZURE_API_BASE']}" + assert ( + model["litellm_params"]["api_version"] + == os.environ["AZURE_API_VERSION"] + ), 
f"{model['litellm_params']['api_version']} vs {os.environ['AZURE_API_VERSION']}" + assert float(model["litellm_params"]["timeout"]) == float( + os.environ["AZURE_TIMEOUT"] + ), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}" + assert float(model["litellm_params"]["stream_timeout"]) == float( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}" + assert int(model["litellm_params"]["max_retries"]) == int( + os.environ["AZURE_MAX_RETRIES"] + ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}" + print("passed testing of reading keys from os.environ") + async_client: openai.AsyncAzureOpenAI = model["async_client"] # type: ignore + assert async_client.api_key == os.environ["AZURE_API_KEY"] + assert async_client.base_url == os.environ["AZURE_API_BASE"] + assert async_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert async_client.timeout == ( + os.environ["AZURE_TIMEOUT"] + ), f"{async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("async client set correctly!") - print("\n Testing sync client") - client: openai.AzureOpenAI = model["client"] # type: ignore - assert client.api_key == os.environ["AZURE_API_KEY"] - assert client.base_url == os.environ["AZURE_API_BASE"] - assert client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert client.timeout == (os.environ["AZURE_TIMEOUT"]), f"{client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("sync client set correctly!") + print("\n Testing async streaming client") - print("\n Testing sync stream client") - stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore - assert stream_client.api_key == os.environ["AZURE_API_KEY"] - assert stream_client.base_url == os.environ["AZURE_API_BASE"] - assert stream_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert stream_client.timeout == (os.environ["AZURE_STREAM_TIMEOUT"]), f"{stream_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("sync stream client set correctly!") + stream_async_client: openai.AsyncAzureOpenAI = model["stream_async_client"] # type: ignore + assert stream_async_client.api_key == os.environ["AZURE_API_KEY"] + assert stream_async_client.base_url == os.environ["AZURE_API_BASE"] + assert stream_async_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert stream_async_client.timeout == ( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{stream_async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("async stream client set correctly!") + + print("\n Testing sync client") + client: openai.AzureOpenAI = model["client"] # type: ignore + assert client.api_key == os.environ["AZURE_API_KEY"] + assert client.base_url == os.environ["AZURE_API_BASE"] + assert client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert client.timeout == ( + os.environ["AZURE_TIMEOUT"] + ), f"{client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("sync client set correctly!") + + print("\n Testing sync stream client") + stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore + assert stream_client.api_key == 
os.environ["AZURE_API_KEY"] + assert stream_client.base_url == os.environ["AZURE_API_BASE"] + assert stream_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert stream_client.timeout == ( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{stream_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("sync stream client set correctly!") + + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_reading_keys_os_environ() def test_reading_openai_keys_os_environ(): - import openai - try: - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - "api_key": "os.environ/OPENAI_API_KEY", - "timeout": "os.environ/AZURE_TIMEOUT", - "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", - "max_retries": "os.environ/AZURE_MAX_RETRIES", - }, - }, - { - "model_name": "text-embedding-ada-002", - "litellm_params": { - "model": "text-embedding-ada-002", - "api_key": "os.environ/OPENAI_API_KEY", - "timeout": "os.environ/AZURE_TIMEOUT", - "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", - "max_retries": "os.environ/AZURE_MAX_RETRIES", - }, - }, + import openai - ] - - router = Router(model_list=model_list) - for model in router.model_list: - assert model["litellm_params"]["api_key"] == os.environ["OPENAI_API_KEY"], f"{model['litellm_params']['api_key']} vs {os.environ['AZURE_API_KEY']}" - assert float(model["litellm_params"]["timeout"]) == float(os.environ["AZURE_TIMEOUT"]), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}" - assert float(model["litellm_params"]["stream_timeout"]) == float(os.environ["AZURE_STREAM_TIMEOUT"]), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}" - assert int(model["litellm_params"]["max_retries"]) == int(os.environ["AZURE_MAX_RETRIES"]), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}" - print("passed testing of reading keys from os.environ") - async_client: openai.AsyncOpenAI = model["async_client"] # type: ignore - assert async_client.api_key == os.environ["OPENAI_API_KEY"] - assert async_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert async_client.timeout == (os.environ["AZURE_TIMEOUT"]), f"{async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("async client set correctly!") + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "os.environ/OPENAI_API_KEY", + "timeout": "os.environ/AZURE_TIMEOUT", + "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", + "max_retries": "os.environ/AZURE_MAX_RETRIES", + }, + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "text-embedding-ada-002", + "api_key": "os.environ/OPENAI_API_KEY", + "timeout": "os.environ/AZURE_TIMEOUT", + "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT", + "max_retries": "os.environ/AZURE_MAX_RETRIES", + }, + }, + ] - print("\n Testing async streaming client") + router = Router(model_list=model_list) + for model in router.model_list: + assert ( + model["litellm_params"]["api_key"] == os.environ["OPENAI_API_KEY"] + ), f"{model['litellm_params']['api_key']} vs {os.environ['AZURE_API_KEY']}" + assert float(model["litellm_params"]["timeout"]) == 
float( + os.environ["AZURE_TIMEOUT"] + ), f"{model['litellm_params']['timeout']} vs {os.environ['AZURE_TIMEOUT']}" + assert float(model["litellm_params"]["stream_timeout"]) == float( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{model['litellm_params']['stream_timeout']} vs {os.environ['AZURE_STREAM_TIMEOUT']}" + assert int(model["litellm_params"]["max_retries"]) == int( + os.environ["AZURE_MAX_RETRIES"] + ), f"{model['litellm_params']['max_retries']} vs {os.environ['AZURE_MAX_RETRIES']}" + print("passed testing of reading keys from os.environ") + async_client: openai.AsyncOpenAI = model["async_client"] # type: ignore + assert async_client.api_key == os.environ["OPENAI_API_KEY"] + assert async_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert async_client.timeout == ( + os.environ["AZURE_TIMEOUT"] + ), f"{async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("async client set correctly!") - stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore - assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"] - assert stream_async_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert stream_async_client.timeout == (os.environ["AZURE_STREAM_TIMEOUT"]), f"{stream_async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("async stream client set correctly!") + print("\n Testing async streaming client") - print("\n Testing sync client") - client: openai.AzureOpenAI = model["client"] # type: ignore - assert client.api_key == os.environ["OPENAI_API_KEY"] - assert client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert client.timeout == (os.environ["AZURE_TIMEOUT"]), f"{client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("sync client set correctly!") + stream_async_client: openai.AsyncOpenAI = model["stream_async_client"] # type: ignore + assert stream_async_client.api_key == os.environ["OPENAI_API_KEY"] + assert stream_async_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{stream_async_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert stream_async_client.timeout == ( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{stream_async_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("async stream client set correctly!") - print("\n Testing sync stream client") - stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore - assert stream_client.api_key == os.environ["OPENAI_API_KEY"] - assert stream_client.max_retries == (os.environ["AZURE_MAX_RETRIES"]), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" - assert stream_client.timeout == (os.environ["AZURE_STREAM_TIMEOUT"]), f"{stream_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" - print("sync stream client set correctly!") + print("\n Testing sync client") + client: openai.AzureOpenAI = model["client"] # type: ignore + assert client.api_key == os.environ["OPENAI_API_KEY"] + assert client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert client.timeout == ( + os.environ["AZURE_TIMEOUT"] + ), f"{client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("sync client set correctly!") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") + print("\n Testing sync 
stream client") + stream_client: openai.AzureOpenAI = model["stream_client"] # type: ignore + assert stream_client.api_key == os.environ["OPENAI_API_KEY"] + assert stream_client.max_retries == ( + os.environ["AZURE_MAX_RETRIES"] + ), f"{stream_client.max_retries} vs {os.environ['AZURE_MAX_RETRIES']}" + assert stream_client.timeout == ( + os.environ["AZURE_STREAM_TIMEOUT"] + ), f"{stream_client.timeout} vs {os.environ['AZURE_TIMEOUT']}" + print("sync stream client set correctly!") -# test_reading_openai_keys_os_environ() \ No newline at end of file + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + + +# test_reading_openai_keys_os_environ() diff --git a/litellm/tests/test_router_caching.py b/litellm/tests/test_router_caching.py index 006a5b50c..67c263aa2 100644 --- a/litellm/tests/test_router_caching.py +++ b/litellm/tests/test_router_caching.py @@ -3,164 +3,184 @@ import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm from litellm import Router -## Scenarios +## Scenarios ## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo", ## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group -@pytest.mark.asyncio -async def test_acompletion_caching_on_router(): - # tests acompletion + caching on router - try: - litellm.set_verbose = True - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 100000, - "rpm": 10000, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION") - }, - "tpm": 100000, - "rpm": 10000, - } - ] - - messages = [ - {"role": "user", "content": f"write a one sentence poem {time.time()}?"} - ] - start_time = time.time() - router = Router(model_list=model_list, - redis_host=os.environ["REDIS_HOST"], - redis_password=os.environ["REDIS_PASSWORD"], - redis_port=os.environ["REDIS_PORT"], - cache_responses=True, - timeout=30, - routing_strategy="simple-shuffle") - response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1) - print(f"response1: {response1}") - await asyncio.sleep(1) # add cache is async, async sleep for cache to get set - response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1) - print(f"response2: {response2}") - assert response1.id == response2.id - assert len(response1.choices[0].message.content) > 0 - assert response1.choices[0].message.content == response2.choices[0].message.content - router.reset() - except litellm.Timeout as e: - end_time = time.time() - print(f"timeout error occurred: {end_time - start_time}") - pass - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") @pytest.mark.asyncio -async def test_acompletion_caching_on_router_caching_groups(): - # tests acompletion + caching on router - try: - litellm.set_verbose = True - model_list = [ - { - "model_name": "openai-gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 100000, - "rpm": 10000, - }, - { - "model_name": "azure-gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": 
os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION") - }, - "tpm": 100000, - "rpm": 10000, - } - ] - - messages = [ - {"role": "user", "content": f"write a one sentence poem {time.time()}?"} - ] - start_time = time.time() - router = Router(model_list=model_list, - redis_host=os.environ["REDIS_HOST"], - redis_password=os.environ["REDIS_PASSWORD"], - redis_port=os.environ["REDIS_PORT"], - cache_responses=True, - timeout=30, - routing_strategy="simple-shuffle", - caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")]) - response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1) - print(f"response1: {response1}") - await asyncio.sleep(1) # add cache is async, async sleep for cache to get set - response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1) - print(f"response2: {response2}") - assert response1.id == response2.id - assert len(response1.choices[0].message.content) > 0 - assert response1.choices[0].message.content == response2.choices[0].message.content - router.reset() - except litellm.Timeout as e: - end_time = time.time() - print(f"timeout error occurred: {end_time - start_time}") - pass - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") - -def test_usage_based_routing_completion(): - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0301", - "api_key": os.getenv("OPENAI_API_KEY"), - "custom_llm_provider": "Custom-LLM" - }, - "tpm": 10000, - "rpm": 5 - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0301", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 10000, - "rpm": 5 - } - ] - router = Router(model_list= model_list, - routing_strategy= "usage-based-routing", - set_verbose= False) - max_requests = 5 - while max_requests > 0: +async def test_acompletion_caching_on_router(): + # tests acompletion + caching on router try: - router.completion( - model='gpt-3.5-turbo', - messages=[{"content": "write a one sentence poem.", "role": "user"}], - ) - except ValueError as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") - finally: - max_requests -= 1 - router.reset() \ No newline at end of file + litellm.set_verbose = True + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + messages = [ + {"role": "user", "content": f"write a one sentence poem {time.time()}?"} + ] + start_time = time.time() + router = Router( + model_list=model_list, + redis_host=os.environ["REDIS_HOST"], + redis_password=os.environ["REDIS_PASSWORD"], + redis_port=os.environ["REDIS_PORT"], + cache_responses=True, + timeout=30, + routing_strategy="simple-shuffle", + ) + response1 = await router.acompletion( + model="gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response1: {response1}") + await asyncio.sleep(1) # add cache is async, async sleep for cache to get set + response2 = await router.acompletion( + model="gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response2: {response2}") + assert response1.id == response2.id + 
assert len(response1.choices[0].message.content) > 0 + assert ( + response1.choices[0].message.content == response2.choices[0].message.content + ) + router.reset() + except litellm.Timeout as e: + end_time = time.time() + print(f"timeout error occurred: {end_time - start_time}") + pass + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.asyncio +async def test_acompletion_caching_on_router_caching_groups(): + # tests acompletion + caching on router + try: + litellm.set_verbose = True + model_list = [ + { + "model_name": "openai-gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 100000, + "rpm": 10000, + }, + { + "model_name": "azure-gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + messages = [ + {"role": "user", "content": f"write a one sentence poem {time.time()}?"} + ] + start_time = time.time() + router = Router( + model_list=model_list, + redis_host=os.environ["REDIS_HOST"], + redis_password=os.environ["REDIS_PASSWORD"], + redis_port=os.environ["REDIS_PORT"], + cache_responses=True, + timeout=30, + routing_strategy="simple-shuffle", + caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")], + ) + response1 = await router.acompletion( + model="openai-gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response1: {response1}") + await asyncio.sleep(1) # add cache is async, async sleep for cache to get set + response2 = await router.acompletion( + model="azure-gpt-3.5-turbo", messages=messages, temperature=1 + ) + print(f"response2: {response2}") + assert response1.id == response2.id + assert len(response1.choices[0].message.content) > 0 + assert ( + response1.choices[0].message.content == response2.choices[0].message.content + ) + router.reset() + except litellm.Timeout as e: + end_time = time.time() + print(f"timeout error occurred: {end_time - start_time}") + pass + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + + +def test_usage_based_routing_completion(): + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0301", + "api_key": os.getenv("OPENAI_API_KEY"), + "custom_llm_provider": "Custom-LLM", + }, + "tpm": 10000, + "rpm": 5, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0301", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 10000, + "rpm": 5, + }, + ] + router = Router( + model_list=model_list, routing_strategy="usage-based-routing", set_verbose=False + ) + max_requests = 5 + while max_requests > 0: + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"content": "write a one sentence poem.", "role": "user"}], + ) + except ValueError as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + finally: + max_requests -= 1 + router.reset() diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 22b5f121e..bfd3ab3c7 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -1,9 +1,10 @@ #### What this tests #### -# This tests calling router with fallback models +# This tests calling router with fallback models import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, 
os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -12,177 +13,204 @@ import litellm from litellm import Router from litellm.integrations.custom_logger import CustomLogger + class MyCustomHandler(CustomLogger): success: bool = False failure: bool = False previous_models: int = 0 - def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") - - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): - print(f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}") - + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + print( + f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}" + ) + def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") - + def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") - - def log_success_event(self, kwargs, response_obj, start_time, end_time): - print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}") - self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} - print(f"self.previous_models: {self.previous_models}") - print(f"On Success") - - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}") - self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + print( + f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}" + ) + self.previous_models += len( + kwargs["litellm_params"]["metadata"]["previous_models"] + ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} print(f"self.previous_models: {self.previous_models}") print(f"On Success") - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print( + f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}" + ) + self.previous_models += len( + kwargs["litellm_params"]["metadata"]["previous_models"] + ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} + print(f"self.previous_models: {self.previous_models}") + print(f"On Success") + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") -kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]} +kwargs = { + "model": "azure/gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hey, how's it going?"}], +} -def test_sync_fallbacks(): + +def test_sync_fallbacks(): try: model_list = [ - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + 
"model": "azure/chatgpt-v-2", "api_key": "bad-key", "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + "rpm": 1800, + }, + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 1000000, - "rpm": 9000 + "rpm": 1800, }, { - "model_name": "gpt-3.5-turbo-16k", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo-16k", + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY"), }, "tpm": 1000000, - "rpm": 9000 - } + "rpm": 9000, + }, + { + "model_name": "gpt-3.5-turbo-16k", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo-16k", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, ] litellm.set_verbose = True customHandler = MyCustomHandler() litellm.callbacks = [customHandler] - router = Router(model_list=model_list, - fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], - context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}], - set_verbose=False) + router = Router( + model_list=model_list, + fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], + context_window_fallbacks=[ + {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, + {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}, + ], + set_verbose=False, + ) response = router.completion(**kwargs) print(f"response: {response}") - time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback print("Passed ! 
Test router_fallbacks: test_sync_fallbacks()") router.reset() except Exception as e: print(e) -# test_sync_fallbacks() + + +# test_sync_fallbacks() + @pytest.mark.asyncio -async def test_async_fallbacks(): +async def test_async_fallbacks(): litellm.set_verbose = False model_list = [ - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": "bad-key", "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + "rpm": 1800, + }, + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 1000000, - "rpm": 9000 + "rpm": 1800, }, { - "model_name": "gpt-3.5-turbo-16k", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo-16k", + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY"), }, "tpm": 1000000, - "rpm": 9000 - } + "rpm": 9000, + }, + { + "model_name": "gpt-3.5-turbo-16k", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo-16k", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, ] - router = Router(model_list=model_list, - fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], - context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}], - set_verbose=False) + router = Router( + model_list=model_list, + fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], + context_window_fallbacks=[ + {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, + {"gpt-3.5-turbo": 
["gpt-3.5-turbo-16k"]}, + ], + set_verbose=False, + ) customHandler = MyCustomHandler() litellm.callbacks = [customHandler] user_message = "Hello, how are you?" @@ -190,152 +218,158 @@ async def test_async_fallbacks(): try: response = await router.acompletion(**kwargs) print(f"customHandler.previous_models: {customHandler.previous_models}") - await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + await asyncio.sleep( + 0.05 + ) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback router.reset() - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: pytest.fail(f"An exception occurred: {e}") finally: router.reset() + # test_async_fallbacks() -def test_dynamic_fallbacks_sync(): + +def test_dynamic_fallbacks_sync(): """ - Allow setting the fallback in the router.completion() call. + Allow setting the fallback in the router.completion() call. """ try: - customHandler = MyCustomHandler() - litellm.callbacks = [customHandler] - model_list = [ - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + customHandler = MyCustomHandler() + litellm.callbacks = [customHandler] + model_list = [ + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": "bad-key", "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + "rpm": 1800, + }, + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 1000000, - "rpm": 9000 + "rpm": 1800, }, { - "model_name": "gpt-3.5-turbo-16k", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo-16k", + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { 
+ "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY"), }, "tpm": 1000000, - "rpm": 9000 - } + "rpm": 9000, + }, + { + "model_name": "gpt-3.5-turbo-16k", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo-16k", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, ] - router = Router(model_list=model_list, set_verbose=True) - kwargs = {} - kwargs["model"] = "azure/gpt-3.5-turbo" - kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] - kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] - response = router.completion(**kwargs) - print(f"response: {response}") - time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback - router.reset() + router = Router(model_list=model_list, set_verbose=True) + kwargs = {} + kwargs["model"] = "azure/gpt-3.5-turbo" + kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}] + kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] + response = router.completion(**kwargs) + print(f"response: {response}") + time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback + router.reset() except Exception as e: pytest.fail(f"An exception occurred - {e}") + # test_dynamic_fallbacks_sync() + @pytest.mark.asyncio -async def test_dynamic_fallbacks_async(): +async def test_dynamic_fallbacks_async(): """ - Allow setting the fallback in the router.completion() call. + Allow setting the fallback in the router.completion() call. 
""" - try: + try: model_list = [ - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": "bad-key", "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { # list of model deployments - "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", + "rpm": 1800, + }, + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": os.getenv("AZURE_API_BASE"), }, "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "azure/gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-functioncalling", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - "tpm": 240000, - "rpm": 1800 - }, - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 1000000, - "rpm": 9000 + "rpm": 1800, }, { - "model_name": "gpt-3.5-turbo-16k", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo-16k", + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY"), }, "tpm": 1000000, - "rpm": 9000 - } + "rpm": 9000, + }, + { + "model_name": "gpt-3.5-turbo-16k", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo-16k", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, ] print() @@ -352,9 +386,13 @@ async def test_dynamic_fallbacks_async(): kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}] response = await router.acompletion(**kwargs) print(f"RESPONSE: {response}") - await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread - assert customHandler.previous_models == 1 # 0 retries, 1 fallback + await asyncio.sleep( + 0.05 + ) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback router.reset() except Exception as e: pytest.fail(f"An exception occurred - {e}") -# asyncio.run(test_dynamic_fallbacks_async()) \ No newline at end of file + + +# 
asyncio.run(test_dynamic_fallbacks_async()) diff --git a/litellm/tests/test_router_get_deployments.py b/litellm/tests/test_router_get_deployments.py index a71fc3823..0a0fcee62 100644 --- a/litellm/tests/test_router_get_deployments.py +++ b/litellm/tests/test_router_get_deployments.py @@ -4,6 +4,7 @@ import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -12,342 +13,365 @@ from litellm import Router from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from dotenv import load_dotenv + load_dotenv() -def test_weighted_selection_router(): - # this tests if load balancing works based on the provided rpms in the router - # it's a fast test, only tests get_available_deployment - # users can pass rpms as a litellm_param - try: - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - "rpm": 6, - }, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "rpm": 1440, - }, - } - ] - router = Router( - model_list=model_list, - ) - selection_counts = defaultdict(int) - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - for _ in range(1000): - selected_model = router.get_available_deployment("gpt-3.5-turbo") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) +def test_weighted_selection_router(): + # this tests if load balancing works based on the provided rpms in the router + # it's a fast test, only tests get_available_deployment + # users can pass rpms as a litellm_param + try: + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + "rpm": 6, + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + "rpm": 1440, + }, + }, + ] + router = Router( + model_list=model_list, + ) + selection_counts = defaultdict(int) - total_requests = sum(selection_counts.values()) + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + for _ in range(1000): + selected_model = router.get_available_deployment("gpt-3.5-turbo") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['azure/chatgpt-v-2'] / total_requests > 0.89, f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. 
Selection counts {selection_counts}" + total_requests = sum(selection_counts.values()) + + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89 + ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" + + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_weighted_selection_router() -def test_weighted_selection_router_tpm(): - # this tests if load balancing works based on the provided tpms in the router - # it's a fast test, only tests get_available_deployment - # users can pass rpms as a litellm_param - try: - print("\ntest weighted selection based on TPM\n") - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - "tpm": 5, - }, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "tpm": 90, - }, - } - ] - router = Router( - model_list=model_list, - ) - selection_counts = defaultdict(int) - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - for _ in range(1000): - selected_model = router.get_available_deployment("gpt-3.5-turbo") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) +def test_weighted_selection_router_tpm(): + # this tests if load balancing works based on the provided tpms in the router + # it's a fast test, only tests get_available_deployment + # users can pass rpms as a litellm_param + try: + print("\ntest weighted selection based on TPM\n") + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + "tpm": 5, + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + "tpm": 90, + }, + }, + ] + router = Router( + model_list=model_list, + ) + selection_counts = defaultdict(int) - total_requests = sum(selection_counts.values()) + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + for _ in range(1000): + selected_model = router.get_available_deployment("gpt-3.5-turbo") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['azure/chatgpt-v-2'] / total_requests > 0.89, f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. 
Selection counts {selection_counts}" + total_requests = sum(selection_counts.values()) + + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89 + ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" + + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_weighted_selection_router_tpm() -def test_weighted_selection_router_tpm_as_router_param(): - # this tests if load balancing works based on the provided tpms in the router - # it's a fast test, only tests get_available_deployment - # users can pass rpms as a litellm_param - try: - print("\ntest weighted selection based on TPM\n") - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "tpm": 5, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - }, - "tpm": 90, - } - ] - router = Router( - model_list=model_list, - ) - selection_counts = defaultdict(int) +def test_weighted_selection_router_tpm_as_router_param(): + # this tests if load balancing works based on the provided tpms in the router + # it's a fast test, only tests get_available_deployment + # users can pass rpms as a litellm_param + try: + print("\ntest weighted selection based on TPM\n") + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 5, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, + "tpm": 90, + }, + ] + router = Router( + model_list=model_list, + ) + selection_counts = defaultdict(int) - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - for _ in range(1000): - selected_model = router.get_available_deployment("gpt-3.5-turbo") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + for _ in range(1000): + selected_model = router.get_available_deployment("gpt-3.5-turbo") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - total_requests = sum(selection_counts.values()) + total_requests = sum(selection_counts.values()) - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['azure/chatgpt-v-2'] / total_requests > 0.89, f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. 
Selection counts {selection_counts}" + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89 + ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" + + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") test_weighted_selection_router_tpm_as_router_param() +def test_weighted_selection_router_rpm_as_router_param(): + # this tests if load balancing works based on the provided tpms in the router + # it's a fast test, only tests get_available_deployment + # users can pass rpms as a litellm_param + try: + print("\ntest weighted selection based on RPM\n") + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "rpm": 5, + "tpm": 5, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + }, + "rpm": 90, + "tpm": 90, + }, + ] + router = Router( + model_list=model_list, + ) + selection_counts = defaultdict(int) -def test_weighted_selection_router_rpm_as_router_param(): - # this tests if load balancing works based on the provided tpms in the router - # it's a fast test, only tests get_available_deployment - # users can pass rpms as a litellm_param - try: - print("\ntest weighted selection based on RPM\n") - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - "rpm": 5, - "tpm": 5, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - }, - "rpm": 90, - "tpm": 90, - } - ] - router = Router( - model_list=model_list, - ) - selection_counts = defaultdict(int) + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + for _ in range(1000): + selected_model = router.get_available_deployment("gpt-3.5-turbo") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - for _ in range(1000): - selected_model = router.get_available_deployment("gpt-3.5-turbo") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) + total_requests = sum(selection_counts.values()) - total_requests = sum(selection_counts.values()) + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89 + ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. 
Selection counts {selection_counts}" - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['azure/chatgpt-v-2'] / total_requests > 0.89, f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_weighted_selection_router_tpm_as_router_param() +def test_weighted_selection_router_no_rpm_set(): + # this tests if we can do selection when no rpm is provided too + # it's a fast test, only tests get_available_deployment + # users can pass rpms as a litellm_param + try: + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + "rpm": 6, + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + "rpm": 1440, + }, + }, + { + "model_name": "claude-1", + "litellm_params": { + "model": "bedrock/claude1.2", + "rpm": 1440, + }, + }, + ] + router = Router( + model_list=model_list, + ) + selection_counts = defaultdict(int) -def test_weighted_selection_router_no_rpm_set(): - # this tests if we can do selection when no rpm is provided too - # it's a fast test, only tests get_available_deployment - # users can pass rpms as a litellm_param - try: - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - "rpm": 6, - }, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "rpm": 1440, - }, - }, - { - "model_name": "claude-1", - "litellm_params": { - "model": "bedrock/claude1.2", - "rpm": 1440, - }, - } - ] - router = Router( - model_list=model_list, - ) - selection_counts = defaultdict(int) + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + for _ in range(1000): + selected_model = router.get_available_deployment("claude-1") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - for _ in range(1000): - selected_model = router.get_available_deployment("claude-1") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) + total_requests = sum(selection_counts.values()) - total_requests = sum(selection_counts.values()) + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["bedrock/claude1.2"] / total_requests == 1 + ), f"Assertion failed: Selection counts {selection_counts}" - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['bedrock/claude1.2'] / total_requests == 1, f"Assertion failed: Selection 
counts {selection_counts}" + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") # test_weighted_selection_router_no_rpm_set() +def test_model_group_aliases(): + try: + litellm.set_verbose = False + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + "tpm": 1, + }, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION"), + "tpm": 99, + }, + }, + { + "model_name": "claude-1", + "litellm_params": { + "model": "bedrock/claude1.2", + "tpm": 1, + }, + }, + ] + router = Router( + model_list=model_list, + model_group_alias={ + "gpt-4": "gpt-3.5-turbo" + }, # gpt-4 requests sent to gpt-3.5-turbo + ) -def test_model_group_aliases(): - try: - litellm.set_verbose = False - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo-0613", - "api_key": os.getenv("OPENAI_API_KEY"), - "tpm": 1, - }, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "tpm": 99, - }, - }, - { - "model_name": "claude-1", - "litellm_params": { - "model": "bedrock/claude1.2", - "tpm": 1, - }, - } - ] - router = Router( - model_list=model_list, - model_group_alias={"gpt-4": "gpt-3.5-turbo"} # gpt-4 requests sent to gpt-3.5-turbo - ) + # test that gpt-4 requests are sent to gpt-3.5-turbo + for _ in range(20): + selected_model = router.get_available_deployment("gpt-4") + print("\n selected model", selected_model) + selected_model_name = selected_model.get("model_name") + if selected_model_name != "gpt-3.5-turbo": + pytest.fail( + f"Selected model {selected_model_name} is not gpt-3.5-turbo" + ) - # test that gpt-4 requests are sent to gpt-3.5-turbo - for _ in range(20): - selected_model = router.get_available_deployment("gpt-4") - print("\n selected model", selected_model) - selected_model_name = selected_model.get("model_name") - if selected_model_name != "gpt-3.5-turbo": - pytest.fail(f"Selected model {selected_model_name} is not gpt-3.5-turbo") - - # test that - # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time - selection_counts = defaultdict(int) - for _ in range(1000): - selected_model = router.get_available_deployment("gpt-3.5-turbo") - selected_model_id = selected_model["litellm_params"]["model"] - selected_model_name = selected_model_id - selection_counts[selected_model_name] +=1 - print(selection_counts) + # test that + # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time + selection_counts = defaultdict(int) + for _ in range(1000): + selected_model = router.get_available_deployment("gpt-3.5-turbo") + selected_model_id = selected_model["litellm_params"]["model"] + selected_model_name = selected_model_id + selection_counts[selected_model_name] += 1 + print(selection_counts) - total_requests = sum(selection_counts.values()) + total_requests = sum(selection_counts.values()) - # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests - assert selection_counts['azure/chatgpt-v-2'] / 
total_requests > 0.89, f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" + # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests + assert ( + selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89 + ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}" - router.reset() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") -# test_model_group_aliases() \ No newline at end of file + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + + +# test_model_group_aliases() diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py index 4d861365e..e36b8319a 100644 --- a/litellm/tests/test_router_init.py +++ b/litellm/tests/test_router_init.py @@ -2,6 +2,7 @@ import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -10,6 +11,7 @@ from litellm import Router from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from dotenv import load_dotenv + load_dotenv() # every time we load the router we should have 4 clients: @@ -18,6 +20,7 @@ load_dotenv() # Async + Stream # Sync + Stream + def test_init_clients(): litellm.set_verbose = True try: @@ -32,7 +35,7 @@ def test_init_clients(): "api_base": os.getenv("AZURE_API_BASE"), "timeout": 0.01, "stream_timeout": 0.000_001, - "max_retries": 7 + "max_retries": 7, }, }, ] @@ -42,7 +45,7 @@ def test_init_clients(): assert elem["async_client"] is not None assert elem["stream_client"] is not None assert elem["stream_async_client"] is not None - + # check if timeout for stream/non stream clients is set correctly async_client = elem["async_client"] stream_async_client = elem["stream_async_client"] @@ -55,6 +58,7 @@ def test_init_clients(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") + # test_init_clients() @@ -80,20 +84,22 @@ def test_init_clients_basic(): assert elem["stream_client"] is not None assert elem["stream_async_client"] is not None print("PASSED !") - + # see if we can init clients without timeout or max retries set except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") + # test_init_clients_basic() def test_timeouts_router(): """ - Test the timeouts of the router with multiple clients. This HASas to raise a timeout error + Test the timeouts of the router with multiple clients. This HASas to raise a timeout error """ import openai + litellm.set_verbose = True try: print("testing init 4 clients with diff timeouts") @@ -111,28 +117,32 @@ def test_timeouts_router(): }, ] router = Router(model_list=model_list) - + print("PASSED !") + async def test(): try: await router.acompletion( model="gpt-3.5-turbo", messages=[ - { - "role": "user", - "content": "hello, write a 20 pg essay" - } + {"role": "user", "content": "hello, write a 20 pg essay"} ], ) except Exception as e: raise e + asyncio.run(test()) except openai.APITimeoutError as e: - print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print( + "Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e + ) print(type(e)) pass except Exception as e: - pytest.fail(f"Did not raise error `openai.APITimeoutError`. 
Instead raised error type: {type(e)}, Error: {e}") + pytest.fail( + f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}" + ) + # test_timeouts_router() @@ -142,7 +152,7 @@ def test_stream_timeouts_router(): Test the stream timeouts router. See if it selected the correct client with stream timeout """ import openai - + litellm.set_verbose = True try: print("testing init 4 clients with diff timeouts") @@ -154,37 +164,35 @@ def test_stream_timeouts_router(): "api_key": os.getenv("AZURE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), "api_base": os.getenv("AZURE_API_BASE"), - "timeout": 200, # regular calls will not timeout, stream calls will + "timeout": 200, # regular calls will not timeout, stream calls will "stream_timeout": 0.000_001, }, }, ] router = Router(model_list=model_list) - + print("PASSED !") selected_client = router._get_client( deployment=router.model_list[0], kwargs={ "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "user", - "content": "hello, write a 20 pg essay" - } - ], - "stream": True + "messages": [{"role": "user", "content": "hello, write a 20 pg essay"}], + "stream": True, }, - client_type=None + client_type=None, ) print("Select client timeout", selected_client.timeout) - assert selected_client.timeout == 0.000_001 + assert selected_client.timeout == 0.000_001 except openai.APITimeoutError as e: - print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print( + "Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e + ) print(type(e)) pass except Exception as e: - pytest.fail(f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}") + pytest.fail( + f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}" + ) + test_stream_timeouts_router() - - diff --git a/litellm/tests/test_rules.py b/litellm/tests/test_rules.py index 7905babc5..664b9db08 100644 --- a/litellm/tests/test_rules.py +++ b/litellm/tests/test_rules.py @@ -1,22 +1,25 @@ #### What this tests #### -# This tests setting rules before / after making llm api calls +# This tests setting rules before / after making llm api calls import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm from litellm import completion, acompletion -def my_pre_call_rule(input: str): + +def my_pre_call_rule(input: str): print(f"input: {input}") print(f"INSIDE MY PRE CALL RULE, len(input) - {len(input)}") - if len(input) > 10: + if len(input) > 10: return False return True -def my_post_call_rule(input: str): + +def my_post_call_rule(input: str): input = input.lower() print(f"input: {input}") print(f"INSIDE MY POST CALL RULE, len(input) - {len(input)}") @@ -24,16 +27,20 @@ def my_post_call_rule(input: str): return False return True -## Test 1: Pre-call rule + +## Test 1: Pre-call rule def test_pre_call_rule(): - try: + try: litellm.pre_call_rules = [my_pre_call_rule] - ### completion - response = completion(model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "say something inappropriate"}]) + ### completion + response = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "say something inappropriate"}], + ) pytest.fail(f"Completion call should have been failed. ") - except: + except: pass + ### async completion async def test_async_response(): user_message = "Hello, how are you?" 
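# Illustrative sketch, not part of the patch: how the pre-call rule exercised above is
# wired up. A rule receives the prompt text; returning False makes the completion call
# fail before any provider request is sent. Assumes OPENAI_API_KEY is set.
import litellm
from litellm import completion


def block_long_prompts(input: str) -> bool:
    # Mirrors my_pre_call_rule above: reject any prompt longer than 10 characters.
    return len(input) <= 10


litellm.pre_call_rules = [block_long_prompts]
try:
    completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say something inappropriate"}],
    )
except Exception as e:
    # Expected: the rule blocks the call, so an exception is raised instead of a response.
    print(f"blocked by pre-call rule: {e}")
finally:
    litellm.pre_call_rules = []  # reset so later calls are unaffected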
@@ -43,22 +50,24 @@ def test_pre_call_rule(): pytest.fail(f"acompletion call should have been failed. ") except Exception as e: pass + asyncio.run(test_async_response()) litellm.pre_call_rules = [] -# test_pre_call_rule() -## Test 2: Post-call rule + +# test_pre_call_rule() +## Test 2: Post-call rule # commenting out of ci/cd since llm's have variable output which was causing our pipeline to fail erratically. # def test_post_call_rule(): -# try: +# try: # litellm.pre_call_rules = [] # litellm.post_call_rules = [my_post_call_rule] -# ### completion -# response = completion(model="gpt-3.5-turbo", +# ### completion +# response = completion(model="gpt-3.5-turbo", # messages=[{"role": "user", "content": "say sorry"}], # fallbacks=["deepinfra/Gryphe/MythoMax-L2-13b"]) # pytest.fail(f"Completion call should have been failed. ") -# except: +# except: # pass # print(f"MAKING ACOMPLETION CALL") # # litellm.set_verbose = True @@ -74,4 +83,4 @@ def test_pre_call_rule(): # litellm.pre_call_rules = [] # litellm.post_call_rules = [] -# test_post_call_rule() \ No newline at end of file +# test_post_call_rule() diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py index 23f67a2e8..3caaf5377 100644 --- a/litellm/tests/test_stream_chunk_builder.py +++ b/litellm/tests/test_stream_chunk_builder.py @@ -1,6 +1,7 @@ import sys, os, time import traceback, asyncio import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -9,55 +10,51 @@ import litellm import os, dotenv from openai import OpenAI import pytest + dotenv.load_dotenv() user_message = "What is the current weather in Boston?" messages = [{"content": user_message, "role": "user"}] function_schema = { - "name": "get_weather", - "description": - "gets the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": - "The city and state, e.g. San Francisco, CA" - }, + "name": "get_weather", + "description": "gets the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], }, - "required": ["location"] - }, } tools_schema = [ { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } + }, } - ] +] # def test_stream_chunk_builder_tools(): -# try: +# try: # litellm.set_verbose = False # response = client.chat.completions.create( # model="gpt-3.5-turbo", @@ -69,67 +66,73 @@ tools_schema = [ # print(f"response: {response}") # print(f"response usage: {response.usage}") -# except Exception as e: +# except Exception as e: # pytest.fail(f"An exception occurred - {str(e)}") # test_stream_chunk_builder_tools() -def test_stream_chunk_builder_litellm_function_call(): - try: - litellm.set_verbose = False - response = litellm.completion( - model="gpt-3.5-turbo", - messages=messages, - functions=[function_schema], - # stream=True, - # complete_response=True # runs stream_chunk_builder under-the-hood - ) - print(f"response: {response}") - except Exception as e: - pytest.fail(f"An exception occurred - {str(e)}") +def test_stream_chunk_builder_litellm_function_call(): + try: + litellm.set_verbose = False + response = litellm.completion( + model="gpt-3.5-turbo", + messages=messages, + functions=[function_schema], + # stream=True, + # complete_response=True # runs stream_chunk_builder under-the-hood + ) + + print(f"response: {response}") + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + # test_stream_chunk_builder_litellm_function_call() -def test_stream_chunk_builder_litellm_tool_call(): - try: - litellm.set_verbose = False - response = litellm.completion( - model="azure/gpt-4-nov-release", - messages=messages, - tools=tools_schema, - stream=True, - api_key="os.environ/AZURE_FRANCE_API_KEY", - api_base="https://openai-france-1234.openai.azure.com", - complete_response = True - ) - print(f"complete response: {response}") - print(f"complete response usage: {response.usage}") - assert response.system_fingerprint is not None - except Exception as e: - pytest.fail(f"An exception occurred - {str(e)}") +def test_stream_chunk_builder_litellm_tool_call(): + try: + litellm.set_verbose = False + response = litellm.completion( + model="azure/gpt-4-nov-release", + messages=messages, + tools=tools_schema, + stream=True, + api_key="os.environ/AZURE_FRANCE_API_KEY", + api_base="https://openai-france-1234.openai.azure.com", + complete_response=True, + ) + + print(f"complete response: {response}") + print(f"complete response usage: {response.usage}") + assert response.system_fingerprint is not None + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + # test_stream_chunk_builder_litellm_tool_call() -def test_stream_chunk_builder_litellm_tool_call_regular_message(): - try: - messages = [{"role": "user", "content": "Hey, how's it going?"}] - litellm.set_verbose = False - response = litellm.completion( - model="azure/gpt-4-nov-release", - messages=messages, - tools=tools_schema, - stream=True, - api_key="os.environ/AZURE_FRANCE_API_KEY", - api_base="https://openai-france-1234.openai.azure.com", - complete_response = True - ) - print(f"complete response: {response}") - print(f"complete response usage: {response.usage}") - assert response.system_fingerprint is not None - except Exception as e: - pytest.fail(f"An exception occurred - {str(e)}") +def test_stream_chunk_builder_litellm_tool_call_regular_message(): + try: + messages = [{"role": "user", "content": "Hey, how's it going?"}] + litellm.set_verbose = False + response = 
litellm.completion( + model="azure/gpt-4-nov-release", + messages=messages, + tools=tools_schema, + stream=True, + api_key="os.environ/AZURE_FRANCE_API_KEY", + api_base="https://openai-france-1234.openai.azure.com", + complete_response=True, + ) + + print(f"complete response: {response}") + print(f"complete response usage: {response.usage}") + assert response.system_fingerprint is not None + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + test_stream_chunk_builder_litellm_tool_call_regular_message() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 818ee6664..5d15e6f2c 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -9,9 +9,17 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from dotenv import load_dotenv + load_dotenv() import litellm -from litellm import completion, acompletion, AuthenticationError, BadRequestError, RateLimitError, ModelResponse +from litellm import ( + completion, + acompletion, + AuthenticationError, + BadRequestError, + RateLimitError, + ModelResponse, +) litellm.logging = False litellm.set_verbose = True @@ -37,30 +45,31 @@ first_openai_chunk_example = { "choices": [ { "index": 0, - "delta": { - "role": "assistant", - "content": "" - }, - "finish_reason": None # it's null + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, # it's null } - ] + ], } + def validate_first_format(chunk): # write a test to make sure chunk follows the same format as first_openai_chunk_example assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary." - assert isinstance(chunk['id'], str), "'id' should be a string." - assert isinstance(chunk['object'], str), "'object' should be a string." - assert isinstance(chunk['created'], int), "'created' should be an integer." - assert isinstance(chunk['model'], str), "'model' should be a string." - assert isinstance(chunk['choices'], list), "'choices' should be a list." + assert isinstance(chunk["id"], str), "'id' should be a string." + assert isinstance(chunk["object"], str), "'object' should be a string." + assert isinstance(chunk["created"], int), "'created' should be an integer." + assert isinstance(chunk["model"], str), "'model' should be a string." + assert isinstance(chunk["choices"], list), "'choices' should be a list." - for choice in chunk['choices']: - assert isinstance(choice['index'], int), "'index' should be an integer." - assert isinstance(choice['delta']['role'], str), "'role' should be a string." + for choice in chunk["choices"]: + assert isinstance(choice["index"], int), "'index' should be an integer." + assert isinstance(choice["delta"]["role"], str), "'role' should be a string." assert "messages" not in choice # openai v1.0.0 returns content as None - assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string." + assert (choice["finish_reason"] is None) or isinstance( + choice["finish_reason"], str + ), "'finish_reason' should be None or a string." 
+ second_openai_chunk_example = { "id": "chatcmpl-7zSKLBVXnX9dwgRuDYVqVVDsgh2yp", @@ -68,99 +77,98 @@ second_openai_chunk_example = { "created": 1694881253, "model": "gpt-4-0613", "choices": [ - { - "index": 0, - "delta": { - "content": "Hello" - }, - "finish_reason": None # it's null - } - ] + {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None} # it's null + ], } + def validate_second_format(chunk): assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary." - assert isinstance(chunk['id'], str), "'id' should be a string." - assert isinstance(chunk['object'], str), "'object' should be a string." - assert isinstance(chunk['created'], int), "'created' should be an integer." - assert isinstance(chunk['model'], str), "'model' should be a string." - assert isinstance(chunk['choices'], list), "'choices' should be a list." + assert isinstance(chunk["id"], str), "'id' should be a string." + assert isinstance(chunk["object"], str), "'object' should be a string." + assert isinstance(chunk["created"], int), "'created' should be an integer." + assert isinstance(chunk["model"], str), "'model' should be a string." + assert isinstance(chunk["choices"], list), "'choices' should be a list." - for choice in chunk['choices']: - assert isinstance(choice['index'], int), "'index' should be an integer." + for choice in chunk["choices"]: + assert isinstance(choice["index"], int), "'index' should be an integer." # openai v1.0.0 returns content as None - assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string." + assert (choice["finish_reason"] is None) or isinstance( + choice["finish_reason"], str + ), "'finish_reason' should be None or a string." + last_openai_chunk_example = { "id": "chatcmpl-7zSKLBVXnX9dwgRuDYVqVVDsgh2yp", "object": "chat.completion.chunk", "created": 1694881253, "model": "gpt-4-0613", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop" - } - ] + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } + def validate_last_format(chunk): assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary." - assert isinstance(chunk['id'], str), "'id' should be a string." - assert isinstance(chunk['object'], str), "'object' should be a string." - assert isinstance(chunk['created'], int), "'created' should be an integer." - assert isinstance(chunk['model'], str), "'model' should be a string." - assert isinstance(chunk['choices'], list), "'choices' should be a list." + assert isinstance(chunk["id"], str), "'id' should be a string." + assert isinstance(chunk["object"], str), "'object' should be a string." + assert isinstance(chunk["created"], int), "'created' should be an integer." + assert isinstance(chunk["model"], str), "'model' should be a string." + assert isinstance(chunk["choices"], list), "'choices' should be a list." + + for choice in chunk["choices"]: + assert isinstance(choice["index"], int), "'index' should be an integer." + assert isinstance( + choice["finish_reason"], str + ), "'finish_reason' should be a string." - for choice in chunk['choices']: - assert isinstance(choice['index'], int), "'index' should be an integer." - assert isinstance(choice['finish_reason'], str), "'finish_reason' should be a string." 
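# Illustrative sketch, not part of the patch: consuming a litellm stream under the
# contract the validators above check -- the first chunk sets role="assistant" with
# empty content, later chunks carry content deltas, and only the final chunk has a
# non-null finish_reason. Assumes OPENAI_API_KEY is set.
import litellm

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
    max_tokens=20,
)

assembled = ""
for idx, chunk in enumerate(stream):
    choice = chunk["choices"][0]
    if idx == 0:
        # first chunk: role is set, content is empty (see validate_first_format)
        assert choice["delta"]["role"] == "assistant"
    if choice["finish_reason"] is not None:
        # finish_reason only appears on the last chunk (see validate_last_format)
        break
    if "content" in choice["delta"] and choice["delta"]["content"] is not None:
        assembled += choice["delta"]["content"]

print(f"assembled response: {assembled}")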
def streaming_format_tests(idx, chunk): - extracted_chunk = "" + extracted_chunk = "" finished = False print(f"chunk: {chunk}") - if idx == 0: # ensure role assistant is set + if idx == 0: # ensure role assistant is set validate_first_format(chunk=chunk) role = chunk["choices"][0]["delta"]["role"] assert role == "assistant" - elif idx == 1: # second chunk + elif idx == 1: # second chunk validate_second_format(chunk=chunk) - if idx != 0: # ensure no role + if idx != 0: # ensure no role if "role" in chunk["choices"][0]["delta"]: - pass # openai v1.0.0+ passes role = None - if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk + pass # openai v1.0.0+ passes role = None + if chunk["choices"][0][ + "finish_reason" + ]: # ensure finish reason is only in last chunk validate_last_format(chunk=chunk) finished = True - if "content" in chunk["choices"][0]["delta"] and chunk["choices"][0]["delta"]["content"] is not None: + if ( + "content" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["content"] is not None + ): extracted_chunk = chunk["choices"][0]["delta"]["content"] print(f"extracted chunk: {extracted_chunk}") return extracted_chunk, finished + tools_schema = [ { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } + }, } - ] +] # def test_completion_cohere_stream(): # # this is a flaky test due to the cohere API endpoint being unstable @@ -186,7 +194,7 @@ tools_schema = [ # complete_response += chunk # if has_finish_reason is False: # raise Exception("Finish reason not in final chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # print(f"completion_response: {complete_response}") # except Exception as e: @@ -194,6 +202,7 @@ tools_schema = [ # test_completion_cohere_stream() + def test_completion_cohere_stream_bad_key(): try: litellm.cache = None @@ -206,7 +215,11 @@ def test_completion_cohere_stream_bad_key(): }, ] response = completion( - model="command-nightly", messages=messages, stream=True, max_tokens=50, api_key=api_key + model="command-nightly", + messages=messages, + stream=True, + max_tokens=50, + api_key=api_key, ) complete_response = "" # Add any assertions here to check the response @@ -219,16 +232,18 @@ def test_completion_cohere_stream_bad_key(): complete_response += chunk if has_finish_reason is False: raise Exception("Finish reason not in final chunk") - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") - except AuthenticationError as e: + except AuthenticationError as e: pass except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_cohere_stream_bad_key() + def test_completion_azure_stream(): try: litellm.set_verbose = False @@ -250,11 +265,14 @@ def test_completion_azure_stream(): if finished: assert isinstance(init_chunk.choices[0], litellm.utils.StreamingChoices) break - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_azure_stream() + + +# test_completion_azure_stream() + def test_completion_azure_function_calling_stream(): try: @@ -262,7 +280,10 @@ def test_completion_azure_function_calling_stream(): user_message = "What is the current weather in Boston?" 
messages = [{"content": user_message, "role": "user"}] response = completion( - model="azure/chatgpt-functioncalling", messages=messages, stream=True, tools=tools_schema + model="azure/chatgpt-functioncalling", + messages=messages, + stream=True, + tools=tools_schema, ) # Add any assertions here to check the response for chunk in response: @@ -274,9 +295,11 @@ def test_completion_azure_function_calling_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_azure_function_calling_stream() -def test_completion_ollama_hosted_stream(): + +def test_completion_ollama_hosted_stream(): try: litellm.set_verbose = True response = completion( @@ -284,7 +307,7 @@ def test_completion_ollama_hosted_stream(): messages=messages, max_tokens=10, api_base="https://test-ollama-endpoint.onrender.com", - stream=True + stream=True, ) # Add any assertions here to check the response complete_response = "" @@ -295,14 +318,16 @@ def test_completion_ollama_hosted_stream(): if finished: assert isinstance(init_chunk.choices[0], litellm.utils.StreamingChoices) break - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"complete_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_ollama_hosted_stream() + def test_completion_claude_stream(): try: messages = [ @@ -322,17 +347,19 @@ def test_completion_claude_stream(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_claude_stream() def test_completion_palm_stream(): try: - litellm.set_verbose=False + litellm.set_verbose = False print("Streaming palm response") messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -342,9 +369,7 @@ def test_completion_palm_stream(): }, ] print("testing palm streaming") - response = completion( - model="palm/chat-bison", messages=messages, stream=True - ) + response = completion(model="palm/chat-bison", messages=messages, stream=True) complete_response = "" # Add any assertions here to check the response @@ -355,11 +380,13 @@ def test_completion_palm_stream(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_palm_stream() @@ -368,7 +395,7 @@ def test_completion_mistral_api_stream(): litellm.set_verbose = True print("Testing streaming mistral api response") response = completion( - model="mistral/mistral-medium", + model="mistral/mistral-medium", messages=[ { "role": "user", @@ -376,7 +403,7 @@ def test_completion_mistral_api_stream(): } ], max_tokens=10, - stream=True + stream=True, ) complete_response = "" for idx, chunk in enumerate(response): @@ -386,15 +413,18 @@ def test_completion_mistral_api_stream(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_mistral_api_stream() + def test_completion_deep_infra_stream(): - # 
deep infra currently includes role in the 2nd chunk + # deep infra currently includes role in the 2nd chunk # waiting for them to make a fix on this try: messages = [ @@ -406,7 +436,10 @@ def test_completion_deep_infra_stream(): ] print("testing deep infra streaming") response = completion( - model="deepinfra/meta-llama/Llama-2-70b-chat-hf", messages=messages, stream=True, max_tokens=80 + model="deepinfra/meta-llama/Llama-2-70b-chat-hf", + messages=messages, + stream=True, + max_tokens=80, ) complete_response = "" @@ -416,13 +449,16 @@ def test_completion_deep_infra_stream(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_deep_infra_stream() + @pytest.mark.skip() def test_completion_nlp_cloud_stream(): try: @@ -435,7 +471,10 @@ def test_completion_nlp_cloud_stream(): ] print("testing nlp cloud streaming") response = completion( - model="nlp_cloud/finetuned-llama-2-70b", messages=messages, stream=True, max_tokens=20 + model="nlp_cloud/finetuned-llama-2-70b", + messages=messages, + stream=True, + max_tokens=20, ) complete_response = "" @@ -445,14 +484,17 @@ def test_completion_nlp_cloud_stream(): complete_response += chunk if finished: break - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: print(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}") + + # test_completion_nlp_cloud_stream() + def test_completion_claude_stream_bad_key(): try: litellm.cache = None @@ -466,7 +508,11 @@ def test_completion_claude_stream_bad_key(): }, ] response = completion( - model="claude-instant-1", messages=messages, stream=True, max_tokens=50, api_key=api_key + model="claude-instant-1", + messages=messages, + stream=True, + max_tokens=50, + api_key=api_key, ) complete_response = "" # Add any assertions here to check the response @@ -475,7 +521,7 @@ def test_completion_claude_stream_bad_key(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"1234completion_response: {complete_response}") raise Exception("Auth error not raised") @@ -485,12 +531,12 @@ def test_completion_claude_stream_bad_key(): pytest.fail(f"Error occurred: {e}") -# test_completion_claude_stream_bad_key() +# test_completion_claude_stream_bad_key() # test_completion_replicate_stream() # def test_completion_vertexai_stream(): # try: -# import os +# import os # os.environ["VERTEXAI_PROJECT"] = "pathrise-convert-1606954137718" # os.environ["VERTEXAI_LOCATION"] = "us-central1" # messages = [ @@ -514,10 +560,10 @@ def test_completion_claude_stream_bad_key(): # complete_response += chunk # if has_finish_reason is False: # raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # print(f"completion_response: {complete_response}") -# except InvalidRequestError as e: +# except InvalidRequestError as e: # pass # except Exception as e: # pytest.fail(f"Error occurred: {e}") @@ -527,7 +573,7 @@ def test_completion_claude_stream_bad_key(): # def test_completion_vertexai_stream_bad_key(): # try: -# import os +# import os # 
messages = [ # {"role": "system", "content": "You are a helpful assistant."}, # { @@ -549,10 +595,10 @@ def test_completion_claude_stream_bad_key(): # complete_response += chunk # if has_finish_reason is False: # raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # print(f"completion_response: {complete_response}") -# except InvalidRequestError as e: +# except InvalidRequestError as e: # pass # except Exception as e: # pytest.fail(f"Error occurred: {e}") @@ -584,14 +630,15 @@ def test_completion_claude_stream_bad_key(): # complete_response += chunk # if has_finish_reason is False: # raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # print(f"completion_response: {complete_response}") -# except InvalidRequestError as e: +# except InvalidRequestError as e: # pass # except Exception as e: # pytest.fail(f"Error occurred: {e}") + def test_completion_replicate_stream_bad_key(): try: api_key = "bad-key" @@ -603,11 +650,11 @@ def test_completion_replicate_stream_bad_key(): }, ] response = completion( - model="replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", - messages=messages, - stream=True, - max_tokens=50, - api_key=api_key + model="replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", + messages=messages, + stream=True, + max_tokens=50, + api_key=api_key, ) complete_response = "" # Add any assertions here to check the response @@ -616,7 +663,7 @@ def test_completion_replicate_stream_bad_key(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except AuthenticationError as e: @@ -625,14 +672,21 @@ def test_completion_replicate_stream_bad_key(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_replicate_stream_bad_key() + def test_completion_bedrock_claude_stream(): try: - litellm.set_verbose=False + litellm.set_verbose = False response = completion( - model="bedrock/anthropic.claude-instant-v1", - messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}], + model="bedrock/anthropic.claude-instant-v1", + messages=[ + { + "role": "user", + "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?", + } + ], temperature=1, max_tokens=20, stream=True, @@ -650,7 +704,7 @@ def test_completion_bedrock_claude_stream(): break if has_finish_reason is False: raise Exception("finish reason not set for last chunk") - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except RateLimitError: @@ -658,14 +712,21 @@ def test_completion_bedrock_claude_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_bedrock_claude_stream() + +# test_completion_bedrock_claude_stream() + def test_completion_bedrock_ai21_stream(): try: - litellm.set_verbose=False + litellm.set_verbose = False response = completion( - model="bedrock/ai21.j2-mid-v1", - messages=[{"role": "user", 
"content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}], + model="bedrock/ai21.j2-mid-v1", + messages=[ + { + "role": "user", + "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?", + } + ], temperature=1, max_tokens=20, stream=True, @@ -683,7 +744,7 @@ def test_completion_bedrock_ai21_stream(): break if has_finish_reason is False: raise Exception("finish reason not set for last chunk") - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except RateLimitError: @@ -691,53 +752,72 @@ def test_completion_bedrock_ai21_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") + # test_completion_bedrock_ai21_stream() -def test_sagemaker_weird_response(): + +def test_sagemaker_weird_response(): """ When the stream ends, flush any remaining holding chunks. """ - try: + try: chunk = """[INST] Hey, how's it going? [/INST] I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have.""" - logging_obj = litellm.Logging(model="berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, stream=True, litellm_call_id="1234", function_id="function_id", call_type="acompletion", start_time=time.time()) - response = litellm.CustomStreamWrapper(completion_stream=chunk, model="berri-benchmarking-Llama-2-70b-chat-hf-4", custom_llm_provider="sagemaker", logging_obj=logging_obj) + logging_obj = litellm.Logging( + model="berri-benchmarking-Llama-2-70b-chat-hf-4", + messages=messages, + stream=True, + litellm_call_id="1234", + function_id="function_id", + call_type="acompletion", + start_time=time.time(), + ) + response = litellm.CustomStreamWrapper( + completion_stream=chunk, + model="berri-benchmarking-Llama-2-70b-chat-hf-4", + custom_llm_provider="sagemaker", + logging_obj=logging_obj, + ) complete_response = "" for chunk in response: complete_response += chunk["choices"][0]["delta"]["content"] assert len(complete_response) > 0 - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + + # test_sagemaker_weird_response() + @pytest.mark.asyncio async def test_sagemaker_streaming_async(): - try: + try: messages = [{"role": "user", "content": "Hey, how's it going?"}] - litellm.set_verbose=True + litellm.set_verbose = True response = await litellm.acompletion( - model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, max_tokens=100, temperature=0.7, stream=True, ) - # Add any assertions here to check the response - complete_response = "" + # Add any assertions here to check the response + complete_response = "" async for chunk in response: - complete_response += chunk.choices[0].delta.content or "" + complete_response += chunk.choices[0].delta.content or "" print(f"complete_response: {complete_response}") assert len(complete_response) > 0 - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + # def test_completion_sagemaker_stream(): # try: # response = completion( -# model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", +# model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", # messages=messages, # temperature=0.2, # max_tokens=80, @@ -754,9 +834,9 @@ async def 
test_sagemaker_streaming_async(): # complete_response += chunk # if has_finish_reason is False: # raise Exception("finish reason not set for last chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") -# except InvalidRequestError as e: +# except InvalidRequestError as e: # pass # except Exception as e: # pytest.fail(f"Error occurred: {e}") @@ -775,24 +855,24 @@ async def test_sagemaker_streaming_async(): # complete_response += chunk # if finished: # break -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # except: # pytest.fail(f"error occurred: {traceback.format_exc()}") # test_maritalk_streaming() # test on openai completion call + # # test on ai21 completion call def ai21_completion_call(): try: - messages=[{ - "role": "system", - "content": "You are an all-knowing oracle", - }, - { - "role": "user", - "content": "What is the meaning of the Universe?" - }] + messages = [ + { + "role": "system", + "content": "You are an all-knowing oracle", + }, + {"role": "user", "content": "What is the meaning of the Universe?"}, + ] response = completion( model="j2-ultra", messages=messages, stream=True, max_tokens=500 ) @@ -808,14 +888,16 @@ def ai21_completion_call(): break if has_finished is False: raise Exception("finished reason missing from final chunk") - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except: pytest.fail(f"error occurred: {traceback.format_exc()}") + # ai21_completion_call() + def ai21_completion_call_bad_key(): try: api_key = "bad-key" @@ -830,20 +912,22 @@ def ai21_completion_call_bad_key(): if finished: break complete_response += chunk - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except: pytest.fail(f"error occurred: {traceback.format_exc()}") + # ai21_completion_call_bad_key() + def hf_test_completion_tgi_stream(): try: response = completion( - model = 'huggingface/HuggingFaceH4/zephyr-7b-beta', - messages = [{ "content": "Hello, how are you?","role": "user"}], - stream=True + model="huggingface/HuggingFaceH4/zephyr-7b-beta", + messages=[{"content": "Hello, how are you?", "role": "user"}], + stream=True, ) # Add any assertions here to check the response print(f"response: {response}") @@ -854,11 +938,13 @@ def hf_test_completion_tgi_stream(): complete_response += chunk if finished: break - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"completion_response: {complete_response}") except Exception as e: pytest.fail(f"Error occurred: {e}") + + # hf_test_completion_tgi_stream() # def test_completion_aleph_alpha(): @@ -878,7 +964,7 @@ def hf_test_completion_tgi_stream(): # break # if has_finished is False: # raise Exception("finished reason missing from final chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response received") # except Exception as e: # pytest.fail(f"Error occurred: {e}") @@ -903,23 +989,22 @@ def hf_test_completion_tgi_stream(): # break # if has_finished is False: # raise Exception("finished reason missing from final chunk") -# if complete_response.strip() == "": +# if complete_response.strip() == "": # raise Exception("Empty response 
received") -# except InvalidRequestError as e: +# except InvalidRequestError as e: # pass # except Exception as e: # pytest.fail(f"Error occurred: {e}") # test_completion_aleph_alpha_bad_key() + # test on openai completion call def test_openai_chat_completion_call(): try: litellm.set_verbose = False print(f"making openai chat completion call") - response = completion( - model="gpt-3.5-turbo", messages=messages, stream=True - ) + response = completion(model="gpt-3.5-turbo", messages=messages, stream=True) complete_response = "" start_time = time.time() for idx, chunk in enumerate(response): @@ -929,27 +1014,34 @@ def test_openai_chat_completion_call(): break complete_response += chunk # print(f'complete_chunk: {complete_response}') - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"complete response: {complete_response}") except: print(f"error occurred: {traceback.format_exc()}") pass + # test_openai_chat_completion_call() + def test_openai_chat_completion_complete_response_call(): try: complete_response = completion( - model="gpt-3.5-turbo", messages=messages, stream=True, complete_response=True + model="gpt-3.5-turbo", + messages=messages, + stream=True, + complete_response=True, ) print(f"complete response: {complete_response}") except: print(f"error occurred: {traceback.format_exc()}") pass + # test_openai_chat_completion_complete_response_call() + def test_openai_text_completion_call(): try: litellm.set_verbose = True @@ -965,15 +1057,17 @@ def test_openai_text_completion_call(): if finished: break # print(f'complete_chunk: {complete_response}') - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"complete response: {complete_response}") except: print(f"error occurred: {traceback.format_exc()}") pass + # test_openai_text_completion_call() + # # test on together ai completion call - starcoder def test_together_ai_completion_call_mistral(): try: @@ -1003,6 +1097,7 @@ def test_together_ai_completion_call_mistral(): print(f"error occurred: {traceback.format_exc()}") pass + def test_together_ai_completion_call_starcoder_bad_key(): try: api_key = "bad-key" @@ -1011,7 +1106,7 @@ def test_together_ai_completion_call_starcoder_bad_key(): model="together_ai/bigcode/starcoder", messages=messages, stream=True, - api_key=api_key + api_key=api_key, ) complete_response = "" has_finish_reason = False @@ -1032,9 +1127,11 @@ def test_together_ai_completion_call_starcoder_bad_key(): print(f"error occurred: {traceback.format_exc()}") pass -# test_together_ai_completion_call_starcoder_bad_key() + +# test_together_ai_completion_call_starcoder_bad_key() #### Test Function calling + streaming #### + def test_completion_openai_with_functions(): function1 = [ { @@ -1054,16 +1151,11 @@ def test_completion_openai_with_functions(): } ] try: - litellm.set_verbose=False + litellm.set_verbose = False response = completion( - model="gpt-3.5-turbo-1106", - messages=[ - { - "role": "user", - "content": "what's the weather in SF" - } - ], - functions=function1, + model="gpt-3.5-turbo-1106", + messages=[{"role": "user", "content": "what's the weather in SF"}], + functions=function1, stream=True, ) # Add any assertions here to check the response @@ -1076,9 +1168,12 @@ def test_completion_openai_with_functions(): print(chunk["choices"][0]["delta"]["content"]) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai_with_functions() #### Test 
Async streaming #### + # # test on ai21 completion call async def ai21_async_completion_call(): try: @@ -1096,19 +1191,25 @@ async def ai21_async_completion_call(): break complete_response += chunk idx += 1 - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"complete response: {complete_response}") except: print(f"error occurred: {traceback.format_exc()}") pass + # asyncio.run(ai21_async_completion_call()) + async def completion_call(): try: response = completion( - model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn, max_tokens=10 + model="gpt-3.5-turbo", + messages=messages, + stream=True, + logger_fn=logger_fn, + max_tokens=10, ) print(f"response: {response}") complete_response = "" @@ -1121,13 +1222,14 @@ async def completion_call(): break complete_response += chunk idx += 1 - if complete_response.strip() == "": + if complete_response.strip() == "": raise Exception("Empty response received") print(f"complete response: {complete_response}") except: print(f"error occurred: {traceback.format_exc()}") pass + # asyncio.run(completion_call()) #### Test Function Calling + Streaming #### @@ -1145,55 +1247,54 @@ final_openai_function_call_example = { "content": None, "function_call": { "name": "get_current_weather", - "arguments": "{\n \"location\": \"Boston, MA\"\n}" - } + "arguments": '{\n "location": "Boston, MA"\n}', + }, }, - "finish_reason": "function_call" + "finish_reason": "function_call", } ], - "usage": { - "prompt_tokens": 82, - "completion_tokens": 18, - "total_tokens": 100 - } + "usage": {"prompt_tokens": 82, "completion_tokens": 18, "total_tokens": 100}, } function_calling_output_structure = { - "id": str, - "object": str, - "created": int, - "model": str, - "choices": [ - { - "index": int, - "message": { - "role": str, - "content": (type(None), str), - "function_call": { - "name": str, - "arguments": str - } - }, - "finish_reason": str - } - ], - "usage": { - "prompt_tokens": int, - "completion_tokens": int, - "total_tokens": int + "id": str, + "object": str, + "created": int, + "model": str, + "choices": [ + { + "index": int, + "message": { + "role": str, + "content": (type(None), str), + "function_call": {"name": str, "arguments": str}, + }, + "finish_reason": str, } - } + ], + "usage": {"prompt_tokens": int, "completion_tokens": int, "total_tokens": int}, +} + def validate_final_structure(item, structure=function_calling_output_structure): if isinstance(item, list): if not all(validate_final_structure(i, structure[0]) for i in item): - return Exception("Function calling final output doesn't match expected output format") + return Exception( + "Function calling final output doesn't match expected output format" + ) elif isinstance(item, dict): - if not all(k in item and validate_final_structure(item[k], v) for k, v in structure.items()): - return Exception("Function calling final output doesn't match expected output format") + if not all( + k in item and validate_final_structure(item[k], v) + for k, v in structure.items() + ): + return Exception( + "Function calling final output doesn't match expected output format" + ) else: if not isinstance(item, structure): - return Exception("Function calling final output doesn't match expected output format") + return Exception( + "Function calling final output doesn't match expected output format" + ) return True @@ -1208,16 +1309,14 @@ first_openai_function_call_example = { "delta": { "role": "assistant", "content": None, - 
"function_call": { - "name": "get_current_weather", - "arguments": "" - } + "function_call": {"name": "get_current_weather", "arguments": ""}, }, - "finish_reason": None + "finish_reason": None, } - ] + ], } + def validate_first_function_call_chunk_structure(item): if not isinstance(item, dict): raise Exception("Incorrect format") @@ -1245,7 +1344,7 @@ def validate_first_function_call_chunk_structure(item): for key in required_keys_in_delta: if key not in choice["delta"]: raise Exception("Incorrect format") - + if not isinstance(choice["delta"]["function_call"], dict): raise Exception("Incorrect format") @@ -1256,6 +1355,7 @@ def validate_first_function_call_chunk_structure(item): return True + second_function_call_chunk_format = { "id": "chatcmpl-7zVRoE5HjHYsCMaVSNgOjzdhbS3P0", "object": "chat.completion.chunk", @@ -1264,14 +1364,10 @@ second_function_call_chunk_format = { "choices": [ { "index": 0, - "delta": { - "function_call": { - "arguments": "{\n" - } - }, - "finish_reason": None + "delta": {"function_call": {"arguments": "{\n"}}, + "finish_reason": None, } - ] + ], } @@ -1295,7 +1391,10 @@ def validate_second_function_call_chunk_structure(data): if key not in choice: raise Exception("Incorrect format") - if "function_call" not in choice["delta"] or "arguments" not in choice["delta"]["function_call"]: + if ( + "function_call" not in choice["delta"] + or "arguments" not in choice["delta"]["function_call"] + ): raise Exception("Incorrect format") return True @@ -1306,13 +1405,7 @@ final_function_call_chunk_example = { "object": "chat.completion.chunk", "created": 1694893248, "model": "gpt-3.5-turbo-0613", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "function_call" - } - ] + "choices": [{"index": 0, "delta": {}, "finish_reason": "function_call"}], } @@ -1338,20 +1431,21 @@ def validate_final_function_call_chunk_structure(data): return True + def streaming_and_function_calling_format_tests(idx, chunk): - extracted_chunk = "" + extracted_chunk = "" finished = False print(f"idx: {idx}") print(f"chunk: {chunk}") decision = False - if idx == 0: # ensure role assistant is set + if idx == 0: # ensure role assistant is set decision = validate_first_function_call_chunk_structure(chunk) role = chunk["choices"][0]["delta"]["role"] assert role == "assistant" - elif idx != 0: # second chunk + elif idx != 0: # second chunk try: decision = validate_second_function_call_chunk_structure(data=chunk) - except: # check if it's the last chunk (returns an empty delta {} ) + except: # check if it's the last chunk (returns an empty delta {} ) decision = validate_final_function_call_chunk_structure(data=chunk) finished = True if "content" in chunk["choices"][0]["delta"]: @@ -1360,6 +1454,7 @@ def streaming_and_function_calling_format_tests(idx, chunk): raise Exception("incorrect format") return extracted_chunk, finished + # def test_openai_streaming_and_function_calling(): # function1 = [ # { @@ -1388,10 +1483,11 @@ def streaming_and_function_calling_format_tests(idx, chunk): # streaming_and_function_calling_format_tests(idx=idx, chunk=chunk) # except Exception as e: # pytest.fail(f"Error occurred: {e}") -# raise e +# raise e # test_openai_streaming_and_function_calling() + def test_success_callback_streaming(): def success_callback(kwargs, completion_response, start_time, end_time): print( @@ -1404,20 +1500,20 @@ def test_success_callback_streaming(): } ) - litellm.success_callback = [success_callback] messages = [{"role": "user", "content": "hello"}] print("TESTING LITELLM COMPLETION 
CALL") response = litellm.completion( - model="j2-light", - messages=messages, stream=True, + model="j2-light", + messages=messages, + stream=True, max_tokens=5, ) print(response) - for chunk in response: print(chunk["choices"][0]) -# test_success_callback_streaming() \ No newline at end of file + +# test_success_callback_streaming() diff --git a/litellm/tests/test_supabase_integration.py b/litellm/tests/test_supabase_integration.py index b07505c23..e92a1de7a 100644 --- a/litellm/tests/test_supabase_integration.py +++ b/litellm/tests/test_supabase_integration.py @@ -4,7 +4,9 @@ import sys, os import traceback import pytest -sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path import litellm from litellm import embedding, completion @@ -15,6 +17,7 @@ litellm.failure_callback = ["supabase"] litellm.set_verbose = False + def test_supabase_logging(): try: response = completion( @@ -27,41 +30,46 @@ def test_supabase_logging(): except Exception as e: print(e) + # test_supabase_logging() + def test_acompletion_sync(): import asyncio import time + async def completion_call(): try: response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "write a poem"}], + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "write a poem"}], max_tokens=10, stream=True, user="ishaanStreamingUser", - timeout=5 + timeout=5, ) complete_response = "" start_time = time.time() async for chunk in response: chunk_time = time.time() - #print(chunk) + # print(chunk) complete_response += chunk["choices"][0]["delta"].get("content", "") - #print(complete_response) - #print(f"time since initial request: {chunk_time - start_time:.5f}") + # print(complete_response) + # print(f"time since initial request: {chunk_time - start_time:.5f}") if chunk["choices"][0].get("finish_reason", None) != None: print("🤗🤗🤗 DONE") return - except litellm.Timeout as e: + except litellm.Timeout as e: pass except: print(f"error occurred: {traceback.format_exc()}") pass asyncio.run(completion_call()) + + # test_acompletion_sync() @@ -69,7 +77,3 @@ def test_acompletion_sync(): litellm.input_callback = [] litellm.success_callback = [] litellm.failure_callback = [] - - - - diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py index f75bd2f7f..2138c60ea 100644 --- a/litellm/tests/test_text_completion.py +++ b/litellm/tests/test_text_completion.py @@ -10,13 +10,2670 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm import embedding, completion, text_completion, completion_cost, atext_completion +from litellm import ( + embedding, + completion, + text_completion, + completion_cost, + atext_completion, +) from litellm import RateLimitError -token_prompt = [[32, 2043, 32, 329, 4585, 262, 1644, 14, 34, 3705, 319, 616, 47551, 30, 930, 19219, 284, 1949, 284, 787, 428, 355, 1790, 355, 1744, 981, 1390, 3307, 2622, 13, 220, 198, 198, 40, 423, 587, 351, 616, 41668, 32682, 329, 718, 812, 13, 376, 666, 32682, 468, 281, 4697, 6621, 11, 356, 1183, 869, 607, 25737, 11, 508, 318, 2579, 290, 468, 257, 642, 614, 1468, 1200, 13, 314, 373, 612, 262, 1110, 25737, 373, 287, 4827, 290, 14801, 373, 4642, 11, 673, 318, 616, 41803, 13, 2399, 2104, 1641, 468, 6412, 284, 502, 355, 465, 38074, 494, 1201, 1110, 352, 13, 314, 716, 407, 2910, 475, 356, 389, 1641, 11, 673, 3848, 502, 38074, 494, 290, 
356, 423, 3993, 13801, 11, 26626, 11864, 11, 3503, 13, 220, 198, 198, 17, 812, 2084, 25737, 373, 287, 14321, 422, 2563, 13230, 13, 21051, 11, 2356, 25542, 11, 290, 47482, 897, 547, 607, 1517, 13, 1375, 550, 257, 5110, 14608, 290, 262, 1641, 7723, 1637, 284, 3758, 607, 284, 14321, 290, 477, 8389, 257, 7269, 284, 1011, 1337, 286, 14801, 13, 383, 5156, 338, 9955, 11, 25737, 338, 13850, 11, 468, 257, 47973, 14, 9979, 2762, 1693, 290, 373, 503, 286, 3240, 329, 362, 1933, 523, 339, 2492, 470, 612, 329, 477, 286, 428, 13, 220, 198, 198, 3347, 10667, 5223, 503, 706, 513, 1528, 11, 23630, 673, 373, 366, 38125, 290, 655, 2622, 257, 3338, 8399, 1911, 314, 2298, 607, 510, 11, 1011, 607, 284, 607, 2156, 11, 290, 673, 3393, 2925, 284, 7523, 20349, 290, 4144, 257, 6099, 13, 314, 836, 470, 892, 20349, 318, 257, 2563, 290, 716, 845, 386, 12, 66, 1236, 571, 292, 3584, 314, 836, 470, 7523, 11, 475, 326, 373, 407, 5035, 6402, 314, 655, 6497, 607, 510, 422, 14321, 13, 220, 198, 198, 32, 1285, 1568, 673, 373, 6294, 329, 3013, 24707, 287, 262, 12436, 1539, 819, 5722, 329, 852, 604, 1933, 2739, 11, 39398, 607, 1097, 5059, 981, 1029, 290, 318, 852, 16334, 329, 720, 1120, 74, 422, 15228, 278, 656, 257, 2156, 11, 290, 373, 12165, 503, 286, 376, 666, 32682, 338, 584, 6621, 338, 2156, 329, 32012, 262, 14595, 373, 30601, 510, 290, 2491, 357, 7091, 373, 1029, 8, 290, 262, 2104, 34624, 373, 46432, 1268, 1961, 422, 1660, 2465, 780, 8168, 2073, 1625, 1363, 329, 807, 2250, 13, 720, 1238, 11, 830, 286, 2465, 290, 5875, 5770, 511, 2156, 5096, 5017, 340, 13, 220, 198, 198, 2504, 373, 477, 938, 614, 13, 1119, 1053, 587, 287, 511, 649, 2156, 319, 511, 898, 329, 546, 718, 1933, 13, 554, 3389, 673, 1444, 34020, 290, 531, 511, 8744, 373, 4423, 572, 780, 673, 1422, 470, 423, 262, 1637, 780, 41646, 338, 37751, 1392, 32621, 510, 290, 1422, 470, 467, 832, 13, 679, 3432, 511, 2739, 8744, 9024, 492, 257, 2472, 286, 720, 4059, 13, 314, 1807, 340, 373, 13678, 306, 5789, 475, 4030, 616, 5422, 4423, 13, 1439, 468, 587, 5897, 1201, 13, 220, 198, 198, 7571, 2745, 2084, 11, 673, 1965, 502, 284, 8804, 617, 1637, 284, 651, 38464, 329, 399, 8535, 13, 3226, 1781, 314, 1101, 407, 1016, 284, 1309, 616, 41803, 393, 6621, 467, 14720, 11, 645, 2300, 644, 318, 1016, 319, 4306, 11, 523, 314, 910, 314, 1183, 307, 625, 379, 642, 13, 314, 1392, 572, 670, 1903, 290, 651, 612, 379, 362, 25, 2231, 13, 314, 1282, 287, 1262, 616, 13952, 1994, 11, 2513, 287, 11, 766, 399, 8535, 2712, 351, 36062, 287, 262, 5228, 11, 25737, 3804, 503, 319, 262, 18507, 11, 290, 16914, 319, 262, 6891, 3084, 13, 8989, 2406, 422, 257, 1641, 47655, 351, 13230, 11, 314, 760, 644, 16914, 3073, 588, 13, 314, 836, 470, 760, 703, 881, 340, 373, 11, 475, 314, 714, 423, 23529, 276, 340, 510, 290, 5901, 616, 18057, 351, 340, 13, 314, 6810, 19772, 2024, 8347, 287, 262, 2166, 2119, 290, 399, 8535, 373, 287, 3294, 11685, 286, 8242, 290, 607, 7374, 15224, 13, 383, 4894, 373, 572, 13, 383, 2156, 373, 3863, 2319, 37, 532, 340, 373, 1542, 2354, 13, 220, 198, 198, 40, 1718, 399, 8535, 284, 616, 1097, 11, 290, 1444, 16679, 329, 281, 22536, 355, 314, 373, 12008, 25737, 373, 14904, 2752, 13, 220, 314, 1422, 470, 765, 284, 10436, 290, 22601, 503, 399, 8535, 523, 314, 9658, 287, 262, 1097, 290, 1309, 607, 711, 319, 616, 3072, 1566, 262, 22536, 5284, 13, 3226, 1781, 1644, 290, 32084, 3751, 510, 355, 880, 13, 314, 4893, 262, 3074, 290, 780, 399, 8535, 338, 9955, 318, 503, 286, 3240, 1762, 11, 34020, 14, 44, 4146, 547, 1444, 13, 1649, 484, 5284, 484, 547, 5897, 290, 4692, 11, 1422, 470, 1107, 1561, 11, 1718, 
399, 8535, 11, 290, 1297, 502, 284, 467, 1363, 13, 220, 198, 198, 2025, 1711, 1568, 314, 651, 1363, 290, 41668, 32682, 7893, 502, 644, 314, 1053, 1760, 13, 314, 4893, 2279, 284, 683, 290, 477, 339, 550, 373, 8993, 329, 502, 13, 18626, 262, 2104, 1641, 1541, 2993, 290, 547, 28674, 379, 502, 329, 644, 314, 550, 1760, 13, 18626, 314, 373, 366, 448, 286, 1627, 290, 8531, 1, 780, 314, 1444, 16679, 878, 4379, 611, 673, 373, 1682, 31245, 6, 278, 780, 340, 2900, 503, 673, 373, 655, 47583, 503, 422, 262, 16914, 13, 775, 8350, 329, 2250, 290, 314, 1364, 290, 3377, 262, 1755, 379, 616, 1266, 1545, 338, 2156, 290, 16896, 477, 1755, 13, 314, 3521, 470, 5412, 340, 477, 523, 314, 2900, 616, 3072, 572, 290, 3088, 284, 8960, 290, 655, 9480, 866, 13, 2011, 1266, 1545, 373, 510, 477, 1755, 351, 502, 11, 5149, 502, 314, 750, 2147, 2642, 11, 290, 314, 1101, 8788, 13, 220, 198, 198, 40, 1210, 616, 3072, 319, 290, 314, 550, 6135, 13399, 14, 37348, 1095, 13, 31515, 11, 34020, 11, 47551, 11, 41668, 32682, 11, 290, 511, 7083, 1641, 1866, 24630, 502, 13, 1119, 389, 2282, 314, 20484, 607, 1204, 11, 20484, 399, 8535, 338, 1204, 11, 925, 2279, 517, 8253, 621, 340, 2622, 284, 307, 11, 925, 340, 1171, 618, 340, 373, 257, 366, 17989, 14669, 1600, 290, 20484, 25737, 338, 8395, 286, 1683, 1972, 20750, 393, 1719, 10804, 286, 607, 1200, 757, 11, 4844, 286, 606, 1683, 765, 284, 766, 502, 757, 290, 314, 481, 1239, 766, 616, 41803, 757, 11, 290, 484, 765, 502, 284, 1414, 329, 25737, 338, 7356, 6314, 290, 20889, 502, 329, 262, 32084, 1339, 290, 7016, 12616, 13, 198, 198, 40, 716, 635, 783, 2060, 13, 1406, 319, 1353, 286, 6078, 616, 1266, 1545, 286, 838, 812, 357, 69, 666, 32682, 828, 314, 481, 4425, 616, 7962, 314, 550, 351, 683, 11, 644, 314, 3177, 616, 1641, 11, 290, 616, 399, 8535, 13, 198, 198, 40, 4988, 1254, 12361, 13, 314, 423, 12361, 9751, 284, 262, 966, 810, 314, 1101, 7960, 2130, 318, 1016, 284, 1282, 651, 366, 260, 18674, 1, 319, 502, 329, 644, 314, 750, 13, 314, 460, 470, 4483, 13, 314, 423, 2626, 767, 8059, 422, 340, 13, 314, 1101, 407, 11029, 329, 7510, 13, 314, 423, 11668, 739, 616, 2951, 13, 314, 1053, 550, 807, 50082, 12, 12545, 287, 734, 2745, 13, 1629, 717, 314, 2936, 523, 6563, 287, 616, 2551, 475, 355, 262, 1528, 467, 416, 314, 1101, 3612, 3863, 484, 547, 826, 290, 314, 815, 423, 10667, 319, 607, 878, 4585, 16679, 290, 852, 5306, 3019, 992, 13, 314, 836, 470, 1337, 546, 25737, 7471, 11, 475, 314, 750, 18344, 257, 642, 614, 1468, 1200, 1497, 422, 607, 3397, 290, 314, 1254, 12361, 546, 340, 13, 314, 760, 2130, 287, 262, 1641, 481, 1011, 607, 287, 11, 475, 340, 338, 1239, 588, 852, 351, 534, 3397, 13, 1375, 481, 1663, 510, 20315, 278, 502, 329, 340, 290, 477, 314, 1053, 1683, 1760, 318, 1842, 607, 355, 616, 898, 13, 220, 198, 198, 22367, 11, 317, 2043, 32, 30, 4222, 1037, 502, 13, 383, 14934, 318, 6600, 502, 6776, 13, 220, 198, 24361, 25, 1148, 428, 2642, 30, 198, 33706, 25, 645], [32, 2043, 32, 329, 4585, 262, 1644, 14, 34, 3705, 319, 616, 47551, 30, 930, 19219, 284, 1949, 284, 787, 428, 355, 1790, 355, 1744, 981, 1390, 3307, 2622, 13, 220, 198, 198, 40, 423, 587, 351, 616, 41668, 32682, 329, 718, 812, 13, 376, 666, 32682, 468, 281, 4697, 6621, 11, 356, 1183, 869, 607, 25737, 11, 508, 318, 2579, 290, 468, 257, 642, 614, 1468, 1200, 13, 314, 373, 612, 262, 1110, 25737, 373, 287, 4827, 290, 14801, 373, 4642, 11, 673, 318, 616, 41803, 13, 2399, 2104, 1641, 468, 6412, 284, 502, 355, 465, 38074, 494, 1201, 1110, 352, 13, 314, 716, 407, 2910, 475, 356, 389, 1641, 11, 673, 3848, 502, 38074, 494, 290, 356, 423, 3993, 
13801, 11, 26626, 11864, 11, 3503, 13, 220, 198, 198, 17, 812, 2084, 25737, 373, 287, 14321, 422, 2563, 13230, 13, 21051, 11, 2356, 25542, 11, 290, 47482, 897, 547, 607, 1517, 13, 1375, 550, 257, 5110, 14608, 290, 262, 1641, 7723, 1637, 284, 3758, 607, 284, 14321, 290, 477, 8389, 257, 7269, 284, 1011, 1337, 286, 14801, 13, 383, 5156, 338, 9955, 11, 25737, 338, 13850, 11, 468, 257, 47973, 14, 9979, 2762, 1693, 290, 373, 503, 286, 3240, 329, 362, 1933, 523, 339, 2492, 470, 612, 329, 477, 286, 428, 13, 220, 198, 198, 3347, 10667, 5223, 503, 706, 513, 1528, 11, 23630, 673, 373, 366, 38125, 290, 655, 2622, 257, 3338, 8399, 1911, 314, 2298, 607, 510, 11, 1011, 607, 284, 607, 2156, 11, 290, 673, 3393, 2925, 284, 7523, 20349, 290, 4144, 257, 6099, 13, 314, 836, 470, 892, 20349, 318, 257, 2563, 290, 716, 845, 386, 12, 66, 1236, 571, 292, 3584, 314, 836, 470, 7523, 11, 475, 326, 373, 407, 5035, 6402, 314, 655, 6497, 607, 510, 422, 14321, 13, 220, 198, 198, 32, 1285, 1568, 673, 373, 6294, 329, 3013, 24707, 287, 262, 12436, 1539, 819, 5722, 329, 852, 604, 1933, 2739, 11, 39398, 607, 1097, 5059, 981, 1029, 290, 318, 852, 16334, 329, 720, 1120, 74, 422, 15228, 278, 656, 257, 2156, 11, 290, 373, 12165, 503, 286, 376, 666, 32682, 338, 584, 6621, 338, 2156, 329, 32012, 262, 14595, 373, 30601, 510, 290, 2491, 357, 7091, 373, 1029, 8, 290, 262, 2104, 34624, 373, 46432, 1268, 1961, 422, 1660, 2465, 780, 8168, 2073, 1625, 1363, 329, 807, 2250, 13, 720, 1238, 11, 830, 286, 2465, 290, 5875, 5770, 511, 2156, 5096, 5017, 340, 13, 220, 198, 198, 2504, 373, 477, 938, 614, 13, 1119, 1053, 587, 287, 511, 649, 2156, 319, 511, 898, 329, 546, 718, 1933, 13, 554, 3389, 673, 1444, 34020, 290, 531, 511, 8744, 373, 4423, 572, 780, 673, 1422, 470, 423, 262, 1637, 780, 41646, 338, 37751, 1392, 32621, 510, 290, 1422, 470, 467, 832, 13, 679, 3432, 511, 2739, 8744, 9024, 492, 257, 2472, 286, 720, 4059, 13, 314, 1807, 340, 373, 13678, 306, 5789, 475, 4030, 616, 5422, 4423, 13, 1439, 468, 587, 5897, 1201, 13, 220, 198, 198, 7571, 2745, 2084, 11, 673, 1965, 502, 284, 8804, 617, 1637, 284, 651, 38464, 329, 399, 8535, 13, 3226, 1781, 314, 1101, 407, 1016, 284, 1309, 616, 41803, 393, 6621, 467, 14720, 11, 645, 2300, 644, 318, 1016, 319, 4306, 11, 523, 314, 910, 314, 1183, 307, 625, 379, 642, 13, 314, 1392, 572, 670, 1903, 290, 651, 612, 379, 362, 25, 2231, 13, 314, 1282, 287, 1262, 616, 13952, 1994, 11, 2513, 287, 11, 766, 399, 8535, 2712, 351, 36062, 287, 262, 5228, 11, 25737, 3804, 503, 319, 262, 18507, 11, 290, 16914, 319, 262, 6891, 3084, 13, 8989, 2406, 422, 257, 1641, 47655, 351, 13230, 11, 314, 760, 644, 16914, 3073, 588, 13, 314, 836, 470, 760, 703, 881, 340, 373, 11, 475, 314, 714, 423, 23529, 276, 340, 510, 290, 5901, 616, 18057, 351, 340, 13, 314, 6810, 19772, 2024, 8347, 287, 262, 2166, 2119, 290, 399, 8535, 373, 287, 3294, 11685, 286, 8242, 290, 607, 7374, 15224, 13, 383, 4894, 373, 572, 13, 383, 2156, 373, 3863, 2319, 37, 532, 340, 373, 1542, 2354, 13, 220, 198, 198, 40, 1718, 399, 8535, 284, 616, 1097, 11, 290, 1444, 16679, 329, 281, 22536, 355, 314, 373, 12008, 25737, 373, 14904, 2752, 13, 220, 314, 1422, 470, 765, 284, 10436, 290, 22601, 503, 399, 8535, 523, 314, 9658, 287, 262, 1097, 290, 1309, 607, 711, 319, 616, 3072, 1566, 262, 22536, 5284, 13, 3226, 1781, 1644, 290, 32084, 3751, 510, 355, 880, 13, 314, 4893, 262, 3074, 290, 780, 399, 8535, 338, 9955, 318, 503, 286, 3240, 1762, 11, 34020, 14, 44, 4146, 547, 1444, 13, 1649, 484, 5284, 484, 547, 5897, 290, 4692, 11, 1422, 470, 1107, 1561, 11, 1718, 399, 8535, 11, 
290, 1297, 502, 284, 467, 1363, 13, 220, 198, 198, 2025, 1711, 1568, 314, 651, 1363, 290, 41668, 32682, 7893, 502, 644, 314, 1053, 1760, 13, 314, 4893, 2279, 284, 683, 290, 477, 339, 550, 373, 8993, 329, 502, 13, 18626, 262, 2104, 1641, 1541, 2993, 290, 547, 28674, 379, 502, 329, 644, 314, 550, 1760, 13, 18626, 314, 373, 366, 448, 286, 1627, 290, 8531, 1, 780, 314, 1444, 16679, 878, 4379, 611, 673, 373, 1682, 31245, 6, 278, 780, 340, 2900, 503, 673, 373, 655, 47583, 503, 422, 262, 16914, 13, 775, 8350, 329, 2250, 290, 314, 1364, 290, 3377, 262, 1755, 379, 616, 1266, 1545, 338, 2156, 290, 16896, 477, 1755, 13, 314, 3521, 470, 5412, 340, 477, 523, 314, 2900, 616, 3072, 572, 290, 3088, 284, 8960, 290, 655, 9480, 866, 13, 2011, 1266, 1545, 373, 510, 477, 1755, 351, 502, 11, 5149, 502, 314, 750, 2147, 2642, 11, 290, 314, 1101, 8788, 13, 220, 198, 198, 40, 1210, 616, 3072, 319, 290, 314, 550, 6135, 13399, 14, 37348, 1095, 13, 31515, 11, 34020, 11, 47551, 11, 41668, 32682, 11, 290, 511, 7083, 1641, 1866, 24630, 502, 13, 1119, 389, 2282, 314, 20484, 607, 1204, 11, 20484, 399, 8535, 338, 1204, 11, 925, 2279, 517, 8253, 621, 340, 2622, 284, 307, 11, 925, 340, 1171, 618, 340, 373, 257, 366, 17989, 14669, 1600, 290, 20484, 25737, 338, 8395, 286, 1683, 1972, 20750, 393, 1719, 10804, 286, 607, 1200, 757, 11, 4844, 286, 606, 1683, 765, 284, 766, 502, 757, 290, 314, 481, 1239, 766, 616, 41803, 757, 11, 290, 484, 765, 502, 284, 1414, 329, 25737, 338, 7356, 6314, 290, 20889, 502, 329, 262, 32084, 1339, 290, 7016, 12616, 13, 198, 198, 40, 716, 635, 783, 2060, 13, 1406, 319, 1353, 286, 6078, 616, 1266, 1545, 286, 838, 812, 357, 69, 666, 32682, 828, 314, 481, 4425, 616, 7962, 314, 550, 351, 683, 11, 644, 314, 3177, 616, 1641, 11, 290, 616, 399, 8535, 13, 198, 198, 40, 4988, 1254, 12361, 13, 314, 423, 12361, 9751, 284, 262, 966, 810, 314, 1101, 7960, 2130, 318, 1016, 284, 1282, 651, 366, 260, 18674, 1, 319, 502, 329, 644, 314, 750, 13, 314, 460, 470, 4483, 13, 314, 423, 2626, 767, 8059, 422, 340, 13, 314, 1101, 407, 11029, 329, 7510, 13, 314, 423, 11668, 739, 616, 2951, 13, 314, 1053, 550, 807, 50082, 12, 12545, 287, 734, 2745, 13, 1629, 717, 314, 2936, 523, 6563, 287, 616, 2551, 475, 355, 262, 1528, 467, 416, 314, 1101, 3612, 3863, 484, 547, 826, 290, 314, 815, 423, 10667, 319, 607, 878, 4585, 16679, 290, 852, 5306, 3019, 992, 13, 314, 836, 470, 1337, 546, 25737, 7471, 11, 475, 314, 750, 18344, 257, 642, 614, 1468, 1200, 1497, 422, 607, 3397, 290, 314, 1254, 12361, 546, 340, 13, 314, 760, 2130, 287, 262, 1641, 481, 1011, 607, 287, 11, 475, 340, 338, 1239, 588, 852, 351, 534, 3397, 13, 1375, 481, 1663, 510, 20315, 278, 502, 329, 340, 290, 477, 314, 1053, 1683, 1760, 318, 1842, 607, 355, 616, 898, 13, 220, 198, 198, 22367, 11, 317, 2043, 32, 30, 4222, 1037, 502, 13, 383, 14934, 318, 6600, 502, 6776, 13, 220, 198, 24361, 25, 1148, 428, 2642, 30, 198, 33706, 25, 3763]] - - +token_prompt = [ + [ + 32, + 2043, + 32, + 329, + 4585, + 262, + 1644, + 14, + 34, + 3705, + 319, + 616, + 47551, + 30, + 930, + 19219, + 284, + 1949, + 284, + 787, + 428, + 355, + 1790, + 355, + 1744, + 981, + 1390, + 3307, + 2622, + 13, + 220, + 198, + 198, + 40, + 423, + 587, + 351, + 616, + 41668, + 32682, + 329, + 718, + 812, + 13, + 376, + 666, + 32682, + 468, + 281, + 4697, + 6621, + 11, + 356, + 1183, + 869, + 607, + 25737, + 11, + 508, + 318, + 2579, + 290, + 468, + 257, + 642, + 614, + 1468, + 1200, + 13, + 314, + 373, + 612, + 262, + 1110, + 25737, + 373, + 287, + 4827, + 290, + 14801, + 373, + 4642, + 11, + 673, + 318, + 616, + 
41803, + 13, + 2399, + 2104, + 1641, + 468, + 6412, + 284, + 502, + 355, + 465, + 38074, + 494, + 1201, + 1110, + 352, + 13, + 314, + 716, + 407, + 2910, + 475, + 356, + 389, + 1641, + 11, + 673, + 3848, + 502, + 38074, + 494, + 290, + 356, + 423, + 3993, + 13801, + 11, + 26626, + 11864, + 11, + 3503, + 13, + 220, + 198, + 198, + 17, + 812, + 2084, + 25737, + 373, + 287, + 14321, + 422, + 2563, + 13230, + 13, + 21051, + 11, + 2356, + 25542, + 11, + 290, + 47482, + 897, + 547, + 607, + 1517, + 13, + 1375, + 550, + 257, + 5110, + 14608, + 290, + 262, + 1641, + 7723, + 1637, + 284, + 3758, + 607, + 284, + 14321, + 290, + 477, + 8389, + 257, + 7269, + 284, + 1011, + 1337, + 286, + 14801, + 13, + 383, + 5156, + 338, + 9955, + 11, + 25737, + 338, + 13850, + 11, + 468, + 257, + 47973, + 14, + 9979, + 2762, + 1693, + 290, + 373, + 503, + 286, + 3240, + 329, + 362, + 1933, + 523, + 339, + 2492, + 470, + 612, + 329, + 477, + 286, + 428, + 13, + 220, + 198, + 198, + 3347, + 10667, + 5223, + 503, + 706, + 513, + 1528, + 11, + 23630, + 673, + 373, + 366, + 38125, + 290, + 655, + 2622, + 257, + 3338, + 8399, + 1911, + 314, + 2298, + 607, + 510, + 11, + 1011, + 607, + 284, + 607, + 2156, + 11, + 290, + 673, + 3393, + 2925, + 284, + 7523, + 20349, + 290, + 4144, + 257, + 6099, + 13, + 314, + 836, + 470, + 892, + 20349, + 318, + 257, + 2563, + 290, + 716, + 845, + 386, + 12, + 66, + 1236, + 571, + 292, + 3584, + 314, + 836, + 470, + 7523, + 11, + 475, + 326, + 373, + 407, + 5035, + 6402, + 314, + 655, + 6497, + 607, + 510, + 422, + 14321, + 13, + 220, + 198, + 198, + 32, + 1285, + 1568, + 673, + 373, + 6294, + 329, + 3013, + 24707, + 287, + 262, + 12436, + 1539, + 819, + 5722, + 329, + 852, + 604, + 1933, + 2739, + 11, + 39398, + 607, + 1097, + 5059, + 981, + 1029, + 290, + 318, + 852, + 16334, + 329, + 720, + 1120, + 74, + 422, + 15228, + 278, + 656, + 257, + 2156, + 11, + 290, + 373, + 12165, + 503, + 286, + 376, + 666, + 32682, + 338, + 584, + 6621, + 338, + 2156, + 329, + 32012, + 262, + 14595, + 373, + 30601, + 510, + 290, + 2491, + 357, + 7091, + 373, + 1029, + 8, + 290, + 262, + 2104, + 34624, + 373, + 46432, + 1268, + 1961, + 422, + 1660, + 2465, + 780, + 8168, + 2073, + 1625, + 1363, + 329, + 807, + 2250, + 13, + 720, + 1238, + 11, + 830, + 286, + 2465, + 290, + 5875, + 5770, + 511, + 2156, + 5096, + 5017, + 340, + 13, + 220, + 198, + 198, + 2504, + 373, + 477, + 938, + 614, + 13, + 1119, + 1053, + 587, + 287, + 511, + 649, + 2156, + 319, + 511, + 898, + 329, + 546, + 718, + 1933, + 13, + 554, + 3389, + 673, + 1444, + 34020, + 290, + 531, + 511, + 8744, + 373, + 4423, + 572, + 780, + 673, + 1422, + 470, + 423, + 262, + 1637, + 780, + 41646, + 338, + 37751, + 1392, + 32621, + 510, + 290, + 1422, + 470, + 467, + 832, + 13, + 679, + 3432, + 511, + 2739, + 8744, + 9024, + 492, + 257, + 2472, + 286, + 720, + 4059, + 13, + 314, + 1807, + 340, + 373, + 13678, + 306, + 5789, + 475, + 4030, + 616, + 5422, + 4423, + 13, + 1439, + 468, + 587, + 5897, + 1201, + 13, + 220, + 198, + 198, + 7571, + 2745, + 2084, + 11, + 673, + 1965, + 502, + 284, + 8804, + 617, + 1637, + 284, + 651, + 38464, + 329, + 399, + 8535, + 13, + 3226, + 1781, + 314, + 1101, + 407, + 1016, + 284, + 1309, + 616, + 41803, + 393, + 6621, + 467, + 14720, + 11, + 645, + 2300, + 644, + 318, + 1016, + 319, + 4306, + 11, + 523, + 314, + 910, + 314, + 1183, + 307, + 625, + 379, + 642, + 13, + 314, + 1392, + 572, + 670, + 1903, + 290, + 651, + 612, + 379, + 362, + 25, + 2231, + 13, + 314, + 1282, + 287, + 1262, + 616, + 13952, + 1994, + 11, + 
2513, + 287, + 11, + 766, + 399, + 8535, + 2712, + 351, + 36062, + 287, + 262, + 5228, + 11, + 25737, + 3804, + 503, + 319, + 262, + 18507, + 11, + 290, + 16914, + 319, + 262, + 6891, + 3084, + 13, + 8989, + 2406, + 422, + 257, + 1641, + 47655, + 351, + 13230, + 11, + 314, + 760, + 644, + 16914, + 3073, + 588, + 13, + 314, + 836, + 470, + 760, + 703, + 881, + 340, + 373, + 11, + 475, + 314, + 714, + 423, + 23529, + 276, + 340, + 510, + 290, + 5901, + 616, + 18057, + 351, + 340, + 13, + 314, + 6810, + 19772, + 2024, + 8347, + 287, + 262, + 2166, + 2119, + 290, + 399, + 8535, + 373, + 287, + 3294, + 11685, + 286, + 8242, + 290, + 607, + 7374, + 15224, + 13, + 383, + 4894, + 373, + 572, + 13, + 383, + 2156, + 373, + 3863, + 2319, + 37, + 532, + 340, + 373, + 1542, + 2354, + 13, + 220, + 198, + 198, + 40, + 1718, + 399, + 8535, + 284, + 616, + 1097, + 11, + 290, + 1444, + 16679, + 329, + 281, + 22536, + 355, + 314, + 373, + 12008, + 25737, + 373, + 14904, + 2752, + 13, + 220, + 314, + 1422, + 470, + 765, + 284, + 10436, + 290, + 22601, + 503, + 399, + 8535, + 523, + 314, + 9658, + 287, + 262, + 1097, + 290, + 1309, + 607, + 711, + 319, + 616, + 3072, + 1566, + 262, + 22536, + 5284, + 13, + 3226, + 1781, + 1644, + 290, + 32084, + 3751, + 510, + 355, + 880, + 13, + 314, + 4893, + 262, + 3074, + 290, + 780, + 399, + 8535, + 338, + 9955, + 318, + 503, + 286, + 3240, + 1762, + 11, + 34020, + 14, + 44, + 4146, + 547, + 1444, + 13, + 1649, + 484, + 5284, + 484, + 547, + 5897, + 290, + 4692, + 11, + 1422, + 470, + 1107, + 1561, + 11, + 1718, + 399, + 8535, + 11, + 290, + 1297, + 502, + 284, + 467, + 1363, + 13, + 220, + 198, + 198, + 2025, + 1711, + 1568, + 314, + 651, + 1363, + 290, + 41668, + 32682, + 7893, + 502, + 644, + 314, + 1053, + 1760, + 13, + 314, + 4893, + 2279, + 284, + 683, + 290, + 477, + 339, + 550, + 373, + 8993, + 329, + 502, + 13, + 18626, + 262, + 2104, + 1641, + 1541, + 2993, + 290, + 547, + 28674, + 379, + 502, + 329, + 644, + 314, + 550, + 1760, + 13, + 18626, + 314, + 373, + 366, + 448, + 286, + 1627, + 290, + 8531, + 1, + 780, + 314, + 1444, + 16679, + 878, + 4379, + 611, + 673, + 373, + 1682, + 31245, + 6, + 278, + 780, + 340, + 2900, + 503, + 673, + 373, + 655, + 47583, + 503, + 422, + 262, + 16914, + 13, + 775, + 8350, + 329, + 2250, + 290, + 314, + 1364, + 290, + 3377, + 262, + 1755, + 379, + 616, + 1266, + 1545, + 338, + 2156, + 290, + 16896, + 477, + 1755, + 13, + 314, + 3521, + 470, + 5412, + 340, + 477, + 523, + 314, + 2900, + 616, + 3072, + 572, + 290, + 3088, + 284, + 8960, + 290, + 655, + 9480, + 866, + 13, + 2011, + 1266, + 1545, + 373, + 510, + 477, + 1755, + 351, + 502, + 11, + 5149, + 502, + 314, + 750, + 2147, + 2642, + 11, + 290, + 314, + 1101, + 8788, + 13, + 220, + 198, + 198, + 40, + 1210, + 616, + 3072, + 319, + 290, + 314, + 550, + 6135, + 13399, + 14, + 37348, + 1095, + 13, + 31515, + 11, + 34020, + 11, + 47551, + 11, + 41668, + 32682, + 11, + 290, + 511, + 7083, + 1641, + 1866, + 24630, + 502, + 13, + 1119, + 389, + 2282, + 314, + 20484, + 607, + 1204, + 11, + 20484, + 399, + 8535, + 338, + 1204, + 11, + 925, + 2279, + 517, + 8253, + 621, + 340, + 2622, + 284, + 307, + 11, + 925, + 340, + 1171, + 618, + 340, + 373, + 257, + 366, + 17989, + 14669, + 1600, + 290, + 20484, + 25737, + 338, + 8395, + 286, + 1683, + 1972, + 20750, + 393, + 1719, + 10804, + 286, + 607, + 1200, + 757, + 11, + 4844, + 286, + 606, + 1683, + 765, + 284, + 766, + 502, + 757, + 290, + 314, + 481, + 1239, + 766, + 616, + 41803, + 757, + 11, + 290, + 484, + 765, + 502, + 284, + 1414, 
+ 329, + 25737, + 338, + 7356, + 6314, + 290, + 20889, + 502, + 329, + 262, + 32084, + 1339, + 290, + 7016, + 12616, + 13, + 198, + 198, + 40, + 716, + 635, + 783, + 2060, + 13, + 1406, + 319, + 1353, + 286, + 6078, + 616, + 1266, + 1545, + 286, + 838, + 812, + 357, + 69, + 666, + 32682, + 828, + 314, + 481, + 4425, + 616, + 7962, + 314, + 550, + 351, + 683, + 11, + 644, + 314, + 3177, + 616, + 1641, + 11, + 290, + 616, + 399, + 8535, + 13, + 198, + 198, + 40, + 4988, + 1254, + 12361, + 13, + 314, + 423, + 12361, + 9751, + 284, + 262, + 966, + 810, + 314, + 1101, + 7960, + 2130, + 318, + 1016, + 284, + 1282, + 651, + 366, + 260, + 18674, + 1, + 319, + 502, + 329, + 644, + 314, + 750, + 13, + 314, + 460, + 470, + 4483, + 13, + 314, + 423, + 2626, + 767, + 8059, + 422, + 340, + 13, + 314, + 1101, + 407, + 11029, + 329, + 7510, + 13, + 314, + 423, + 11668, + 739, + 616, + 2951, + 13, + 314, + 1053, + 550, + 807, + 50082, + 12, + 12545, + 287, + 734, + 2745, + 13, + 1629, + 717, + 314, + 2936, + 523, + 6563, + 287, + 616, + 2551, + 475, + 355, + 262, + 1528, + 467, + 416, + 314, + 1101, + 3612, + 3863, + 484, + 547, + 826, + 290, + 314, + 815, + 423, + 10667, + 319, + 607, + 878, + 4585, + 16679, + 290, + 852, + 5306, + 3019, + 992, + 13, + 314, + 836, + 470, + 1337, + 546, + 25737, + 7471, + 11, + 475, + 314, + 750, + 18344, + 257, + 642, + 614, + 1468, + 1200, + 1497, + 422, + 607, + 3397, + 290, + 314, + 1254, + 12361, + 546, + 340, + 13, + 314, + 760, + 2130, + 287, + 262, + 1641, + 481, + 1011, + 607, + 287, + 11, + 475, + 340, + 338, + 1239, + 588, + 852, + 351, + 534, + 3397, + 13, + 1375, + 481, + 1663, + 510, + 20315, + 278, + 502, + 329, + 340, + 290, + 477, + 314, + 1053, + 1683, + 1760, + 318, + 1842, + 607, + 355, + 616, + 898, + 13, + 220, + 198, + 198, + 22367, + 11, + 317, + 2043, + 32, + 30, + 4222, + 1037, + 502, + 13, + 383, + 14934, + 318, + 6600, + 502, + 6776, + 13, + 220, + 198, + 24361, + 25, + 1148, + 428, + 2642, + 30, + 198, + 33706, + 25, + 645, + ], + [ + 32, + 2043, + 32, + 329, + 4585, + 262, + 1644, + 14, + 34, + 3705, + 319, + 616, + 47551, + 30, + 930, + 19219, + 284, + 1949, + 284, + 787, + 428, + 355, + 1790, + 355, + 1744, + 981, + 1390, + 3307, + 2622, + 13, + 220, + 198, + 198, + 40, + 423, + 587, + 351, + 616, + 41668, + 32682, + 329, + 718, + 812, + 13, + 376, + 666, + 32682, + 468, + 281, + 4697, + 6621, + 11, + 356, + 1183, + 869, + 607, + 25737, + 11, + 508, + 318, + 2579, + 290, + 468, + 257, + 642, + 614, + 1468, + 1200, + 13, + 314, + 373, + 612, + 262, + 1110, + 25737, + 373, + 287, + 4827, + 290, + 14801, + 373, + 4642, + 11, + 673, + 318, + 616, + 41803, + 13, + 2399, + 2104, + 1641, + 468, + 6412, + 284, + 502, + 355, + 465, + 38074, + 494, + 1201, + 1110, + 352, + 13, + 314, + 716, + 407, + 2910, + 475, + 356, + 389, + 1641, + 11, + 673, + 3848, + 502, + 38074, + 494, + 290, + 356, + 423, + 3993, + 13801, + 11, + 26626, + 11864, + 11, + 3503, + 13, + 220, + 198, + 198, + 17, + 812, + 2084, + 25737, + 373, + 287, + 14321, + 422, + 2563, + 13230, + 13, + 21051, + 11, + 2356, + 25542, + 11, + 290, + 47482, + 897, + 547, + 607, + 1517, + 13, + 1375, + 550, + 257, + 5110, + 14608, + 290, + 262, + 1641, + 7723, + 1637, + 284, + 3758, + 607, + 284, + 14321, + 290, + 477, + 8389, + 257, + 7269, + 284, + 1011, + 1337, + 286, + 14801, + 13, + 383, + 5156, + 338, + 9955, + 11, + 25737, + 338, + 13850, + 11, + 468, + 257, + 47973, + 14, + 9979, + 2762, + 1693, + 290, + 373, + 503, + 286, + 3240, + 329, + 362, + 1933, + 523, + 339, + 2492, + 470, + 612, + 
329, + 477, + 286, + 428, + 13, + 220, + 198, + 198, + 3347, + 10667, + 5223, + 503, + 706, + 513, + 1528, + 11, + 23630, + 673, + 373, + 366, + 38125, + 290, + 655, + 2622, + 257, + 3338, + 8399, + 1911, + 314, + 2298, + 607, + 510, + 11, + 1011, + 607, + 284, + 607, + 2156, + 11, + 290, + 673, + 3393, + 2925, + 284, + 7523, + 20349, + 290, + 4144, + 257, + 6099, + 13, + 314, + 836, + 470, + 892, + 20349, + 318, + 257, + 2563, + 290, + 716, + 845, + 386, + 12, + 66, + 1236, + 571, + 292, + 3584, + 314, + 836, + 470, + 7523, + 11, + 475, + 326, + 373, + 407, + 5035, + 6402, + 314, + 655, + 6497, + 607, + 510, + 422, + 14321, + 13, + 220, + 198, + 198, + 32, + 1285, + 1568, + 673, + 373, + 6294, + 329, + 3013, + 24707, + 287, + 262, + 12436, + 1539, + 819, + 5722, + 329, + 852, + 604, + 1933, + 2739, + 11, + 39398, + 607, + 1097, + 5059, + 981, + 1029, + 290, + 318, + 852, + 16334, + 329, + 720, + 1120, + 74, + 422, + 15228, + 278, + 656, + 257, + 2156, + 11, + 290, + 373, + 12165, + 503, + 286, + 376, + 666, + 32682, + 338, + 584, + 6621, + 338, + 2156, + 329, + 32012, + 262, + 14595, + 373, + 30601, + 510, + 290, + 2491, + 357, + 7091, + 373, + 1029, + 8, + 290, + 262, + 2104, + 34624, + 373, + 46432, + 1268, + 1961, + 422, + 1660, + 2465, + 780, + 8168, + 2073, + 1625, + 1363, + 329, + 807, + 2250, + 13, + 720, + 1238, + 11, + 830, + 286, + 2465, + 290, + 5875, + 5770, + 511, + 2156, + 5096, + 5017, + 340, + 13, + 220, + 198, + 198, + 2504, + 373, + 477, + 938, + 614, + 13, + 1119, + 1053, + 587, + 287, + 511, + 649, + 2156, + 319, + 511, + 898, + 329, + 546, + 718, + 1933, + 13, + 554, + 3389, + 673, + 1444, + 34020, + 290, + 531, + 511, + 8744, + 373, + 4423, + 572, + 780, + 673, + 1422, + 470, + 423, + 262, + 1637, + 780, + 41646, + 338, + 37751, + 1392, + 32621, + 510, + 290, + 1422, + 470, + 467, + 832, + 13, + 679, + 3432, + 511, + 2739, + 8744, + 9024, + 492, + 257, + 2472, + 286, + 720, + 4059, + 13, + 314, + 1807, + 340, + 373, + 13678, + 306, + 5789, + 475, + 4030, + 616, + 5422, + 4423, + 13, + 1439, + 468, + 587, + 5897, + 1201, + 13, + 220, + 198, + 198, + 7571, + 2745, + 2084, + 11, + 673, + 1965, + 502, + 284, + 8804, + 617, + 1637, + 284, + 651, + 38464, + 329, + 399, + 8535, + 13, + 3226, + 1781, + 314, + 1101, + 407, + 1016, + 284, + 1309, + 616, + 41803, + 393, + 6621, + 467, + 14720, + 11, + 645, + 2300, + 644, + 318, + 1016, + 319, + 4306, + 11, + 523, + 314, + 910, + 314, + 1183, + 307, + 625, + 379, + 642, + 13, + 314, + 1392, + 572, + 670, + 1903, + 290, + 651, + 612, + 379, + 362, + 25, + 2231, + 13, + 314, + 1282, + 287, + 1262, + 616, + 13952, + 1994, + 11, + 2513, + 287, + 11, + 766, + 399, + 8535, + 2712, + 351, + 36062, + 287, + 262, + 5228, + 11, + 25737, + 3804, + 503, + 319, + 262, + 18507, + 11, + 290, + 16914, + 319, + 262, + 6891, + 3084, + 13, + 8989, + 2406, + 422, + 257, + 1641, + 47655, + 351, + 13230, + 11, + 314, + 760, + 644, + 16914, + 3073, + 588, + 13, + 314, + 836, + 470, + 760, + 703, + 881, + 340, + 373, + 11, + 475, + 314, + 714, + 423, + 23529, + 276, + 340, + 510, + 290, + 5901, + 616, + 18057, + 351, + 340, + 13, + 314, + 6810, + 19772, + 2024, + 8347, + 287, + 262, + 2166, + 2119, + 290, + 399, + 8535, + 373, + 287, + 3294, + 11685, + 286, + 8242, + 290, + 607, + 7374, + 15224, + 13, + 383, + 4894, + 373, + 572, + 13, + 383, + 2156, + 373, + 3863, + 2319, + 37, + 532, + 340, + 373, + 1542, + 2354, + 13, + 220, + 198, + 198, + 40, + 1718, + 399, + 8535, + 284, + 616, + 1097, + 11, + 290, + 1444, + 16679, + 329, + 281, + 22536, + 355, + 
314, + 373, + 12008, + 25737, + 373, + 14904, + 2752, + 13, + 220, + 314, + 1422, + 470, + 765, + 284, + 10436, + 290, + 22601, + 503, + 399, + 8535, + 523, + 314, + 9658, + 287, + 262, + 1097, + 290, + 1309, + 607, + 711, + 319, + 616, + 3072, + 1566, + 262, + 22536, + 5284, + 13, + 3226, + 1781, + 1644, + 290, + 32084, + 3751, + 510, + 355, + 880, + 13, + 314, + 4893, + 262, + 3074, + 290, + 780, + 399, + 8535, + 338, + 9955, + 318, + 503, + 286, + 3240, + 1762, + 11, + 34020, + 14, + 44, + 4146, + 547, + 1444, + 13, + 1649, + 484, + 5284, + 484, + 547, + 5897, + 290, + 4692, + 11, + 1422, + 470, + 1107, + 1561, + 11, + 1718, + 399, + 8535, + 11, + 290, + 1297, + 502, + 284, + 467, + 1363, + 13, + 220, + 198, + 198, + 2025, + 1711, + 1568, + 314, + 651, + 1363, + 290, + 41668, + 32682, + 7893, + 502, + 644, + 314, + 1053, + 1760, + 13, + 314, + 4893, + 2279, + 284, + 683, + 290, + 477, + 339, + 550, + 373, + 8993, + 329, + 502, + 13, + 18626, + 262, + 2104, + 1641, + 1541, + 2993, + 290, + 547, + 28674, + 379, + 502, + 329, + 644, + 314, + 550, + 1760, + 13, + 18626, + 314, + 373, + 366, + 448, + 286, + 1627, + 290, + 8531, + 1, + 780, + 314, + 1444, + 16679, + 878, + 4379, + 611, + 673, + 373, + 1682, + 31245, + 6, + 278, + 780, + 340, + 2900, + 503, + 673, + 373, + 655, + 47583, + 503, + 422, + 262, + 16914, + 13, + 775, + 8350, + 329, + 2250, + 290, + 314, + 1364, + 290, + 3377, + 262, + 1755, + 379, + 616, + 1266, + 1545, + 338, + 2156, + 290, + 16896, + 477, + 1755, + 13, + 314, + 3521, + 470, + 5412, + 340, + 477, + 523, + 314, + 2900, + 616, + 3072, + 572, + 290, + 3088, + 284, + 8960, + 290, + 655, + 9480, + 866, + 13, + 2011, + 1266, + 1545, + 373, + 510, + 477, + 1755, + 351, + 502, + 11, + 5149, + 502, + 314, + 750, + 2147, + 2642, + 11, + 290, + 314, + 1101, + 8788, + 13, + 220, + 198, + 198, + 40, + 1210, + 616, + 3072, + 319, + 290, + 314, + 550, + 6135, + 13399, + 14, + 37348, + 1095, + 13, + 31515, + 11, + 34020, + 11, + 47551, + 11, + 41668, + 32682, + 11, + 290, + 511, + 7083, + 1641, + 1866, + 24630, + 502, + 13, + 1119, + 389, + 2282, + 314, + 20484, + 607, + 1204, + 11, + 20484, + 399, + 8535, + 338, + 1204, + 11, + 925, + 2279, + 517, + 8253, + 621, + 340, + 2622, + 284, + 307, + 11, + 925, + 340, + 1171, + 618, + 340, + 373, + 257, + 366, + 17989, + 14669, + 1600, + 290, + 20484, + 25737, + 338, + 8395, + 286, + 1683, + 1972, + 20750, + 393, + 1719, + 10804, + 286, + 607, + 1200, + 757, + 11, + 4844, + 286, + 606, + 1683, + 765, + 284, + 766, + 502, + 757, + 290, + 314, + 481, + 1239, + 766, + 616, + 41803, + 757, + 11, + 290, + 484, + 765, + 502, + 284, + 1414, + 329, + 25737, + 338, + 7356, + 6314, + 290, + 20889, + 502, + 329, + 262, + 32084, + 1339, + 290, + 7016, + 12616, + 13, + 198, + 198, + 40, + 716, + 635, + 783, + 2060, + 13, + 1406, + 319, + 1353, + 286, + 6078, + 616, + 1266, + 1545, + 286, + 838, + 812, + 357, + 69, + 666, + 32682, + 828, + 314, + 481, + 4425, + 616, + 7962, + 314, + 550, + 351, + 683, + 11, + 644, + 314, + 3177, + 616, + 1641, + 11, + 290, + 616, + 399, + 8535, + 13, + 198, + 198, + 40, + 4988, + 1254, + 12361, + 13, + 314, + 423, + 12361, + 9751, + 284, + 262, + 966, + 810, + 314, + 1101, + 7960, + 2130, + 318, + 1016, + 284, + 1282, + 651, + 366, + 260, + 18674, + 1, + 319, + 502, + 329, + 644, + 314, + 750, + 13, + 314, + 460, + 470, + 4483, + 13, + 314, + 423, + 2626, + 767, + 8059, + 422, + 340, + 13, + 314, + 1101, + 407, + 11029, + 329, + 7510, + 13, + 314, + 423, + 11668, + 739, + 616, + 2951, + 13, + 314, + 1053, + 550, + 
807, + 50082, + 12, + 12545, + 287, + 734, + 2745, + 13, + 1629, + 717, + 314, + 2936, + 523, + 6563, + 287, + 616, + 2551, + 475, + 355, + 262, + 1528, + 467, + 416, + 314, + 1101, + 3612, + 3863, + 484, + 547, + 826, + 290, + 314, + 815, + 423, + 10667, + 319, + 607, + 878, + 4585, + 16679, + 290, + 852, + 5306, + 3019, + 992, + 13, + 314, + 836, + 470, + 1337, + 546, + 25737, + 7471, + 11, + 475, + 314, + 750, + 18344, + 257, + 642, + 614, + 1468, + 1200, + 1497, + 422, + 607, + 3397, + 290, + 314, + 1254, + 12361, + 546, + 340, + 13, + 314, + 760, + 2130, + 287, + 262, + 1641, + 481, + 1011, + 607, + 287, + 11, + 475, + 340, + 338, + 1239, + 588, + 852, + 351, + 534, + 3397, + 13, + 1375, + 481, + 1663, + 510, + 20315, + 278, + 502, + 329, + 340, + 290, + 477, + 314, + 1053, + 1683, + 1760, + 318, + 1842, + 607, + 355, + 616, + 898, + 13, + 220, + 198, + 198, + 22367, + 11, + 317, + 2043, + 32, + 30, + 4222, + 1037, + 502, + 13, + 383, + 14934, + 318, + 6600, + 502, + 6776, + 13, + 220, + 198, + 24361, + 25, + 1148, + 428, + 2642, + 30, + 198, + 33706, + 25, + 3763, + ], +] def test_completion_openai_prompt(): @@ -28,39 +2685,50 @@ def test_completion_openai_prompt(): print(response) response_str = response["choices"][0]["text"] # print(response.choices[0]) - #print(response.choices[0].text) + # print(response.choices[0].text) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai_prompt() + def test_completion_openai_engine_and_model(): try: print("\n text 003 test\n") - litellm.set_verbose=True + litellm.set_verbose = True response = text_completion( - model="text-davinci-003", engine="anything", prompt="What's the weather in SF?", max_tokens=5 + model="text-davinci-003", + engine="anything", + prompt="What's the weather in SF?", + max_tokens=5, ) print(response) response_str = response["choices"][0]["text"] # print(response.choices[0]) - #print(response.choices[0].text) + # print(response.choices[0].text) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai_engine_and_model() + def test_completion_openai_engine(): try: print("\n text 003 test\n") - litellm.set_verbose=True + litellm.set_verbose = True response = text_completion( engine="text-davinci-003", prompt="What's the weather in SF?", max_tokens=5 ) print(response) response_str = response["choices"][0]["text"] # print(response.choices[0]) - #print(response.choices[0].text) + # print(response.choices[0].text) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_openai_engine() @@ -74,35 +2742,43 @@ def test_completion_chatgpt_prompt(): response_str = response["choices"][0]["text"] print("\n", response.choices) print("\n", response.choices[0]) - #print(response.choices[0].text) + # print(response.choices[0].text) except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_chatgpt_prompt() def test_text_completion_basic(): try: print("\n test 003 with echo and logprobs \n") - litellm.set_verbose=False + litellm.set_verbose = False response = text_completion( - model="text-davinci-003", prompt="good morning", max_tokens=10, logprobs=10, echo=True + model="text-davinci-003", + prompt="good morning", + max_tokens=10, + logprobs=10, + echo=True, ) print(response) print(response.choices) print(response.choices[0]) - #print(response.choices[0].text) + # print(response.choices[0].text) response_str = response["choices"][0]["text"] except Exception as e: pytest.fail(f"Error occurred: {e}") + + # 
test_text_completion_basic() def test_completion_text_003_prompt_array(): try: - litellm.set_verbose=False + litellm.set_verbose = False response = text_completion( - model="text-davinci-003", - prompt=token_prompt, # token prompt is a 2d list + model="text-davinci-003", + prompt=token_prompt, # token prompt is a 2d list ) print("\n\n response") @@ -110,6 +2786,8 @@ def test_completion_text_003_prompt_array(): # response_str = response["choices"][0]["text"] except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_text_003_prompt_array() @@ -132,47 +2810,52 @@ def test_completion_text_003_prompt_array(): # pytest.fail(f"Error occurred: {e}") # test_text_completion_with_proxy() + ##### hugging face tests def test_completion_hf_prompt_array(): try: - litellm.set_verbose=True + litellm.set_verbose = True print("\n testing hf mistral\n") response = text_completion( - model="huggingface/mistralai/Mistral-7B-v0.1", - prompt=token_prompt, # token prompt is a 2d list, + model="huggingface/mistralai/Mistral-7B-v0.1", + prompt=token_prompt, # token prompt is a 2d list, max_tokens=0, temperature=0.0, - # echo=True, # hugging face inference api is currently raising errors for this, looks like they have a regression on their side + # echo=True, # hugging face inference api is currently raising errors for this, looks like they have a regression on their side ) print("\n\n response") print(response) print(response.choices) - assert(len(response.choices)==2) + assert len(response.choices) == 2 # response_str = response["choices"][0]["text"] except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_completion_hf_prompt_array() + def test_text_completion_stream(): try: response = text_completion( - model="huggingface/mistralai/Mistral-7B-v0.1", - prompt="good morning", - stream=True, - max_tokens=10, - ) + model="huggingface/mistralai/Mistral-7B-v0.1", + prompt="good morning", + stream=True, + max_tokens=10, + ) for chunk in response: print(f"chunk: {chunk}") except Exception as e: pytest.fail(f"GOT exception for HF In streaming{e}") + # test_text_completion_stream() # async def test_text_completion_async_stream(): # try: # response = await atext_completion( -# model="text-completion-openai/text-davinci-003", +# model="text-completion-openai/text-davinci-003", # prompt="good morning", # stream=True, # max_tokens=10, @@ -183,23 +2866,27 @@ def test_text_completion_stream(): # pytest.fail(f"GOT exception for HF In streaming{e}") # asyncio.run(test_text_completion_async_stream()) - + + def test_async_text_completion(): litellm.set_verbose = True - print('test_async_text_completion') + print("test_async_text_completion") + async def test_get_response(): try: response = await litellm.atext_completion( - model="gpt-3.5-turbo-instruct", + model="gpt-3.5-turbo-instruct", prompt="good morning", stream=False, - max_tokens=10 + max_tokens=10, ) print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: print(e) - except Exception as e: + except Exception as e: print(e) asyncio.run(test_get_response()) -test_async_text_completion() \ No newline at end of file + + +test_async_text_completion() diff --git a/litellm/tests/test_timeout.py b/litellm/tests/test_timeout.py index 64c30e51e..68bbfb0aa 100644 --- a/litellm/tests/test_timeout.py +++ b/litellm/tests/test_timeout.py @@ -15,49 +15,50 @@ import pytest def test_timeout(): # this Will Raise a timeout - litellm.set_verbose=False + litellm.set_verbose = False try: response = litellm.completion( 
model="gpt-3.5-turbo", timeout=0.01, - messages=[ - { - "role": "user", - "content": "hello, write a 20 pg essay" - } - ] + messages=[{"role": "user", "content": "hello, write a 20 pg essay"}], ) except openai.APITimeoutError as e: - print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print( + "Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e + ) print(type(e)) pass except Exception as e: - pytest.fail(f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}") -# test_timeout() + pytest.fail( + f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}" + ) +# test_timeout() + def test_timeout_streaming(): # this Will Raise a timeout - litellm.set_verbose=False + litellm.set_verbose = False try: response = litellm.completion( model="gpt-3.5-turbo", - messages=[ - { - "role": "user", - "content": "hello, write a 20 pg essay" - } - ], + messages=[{"role": "user", "content": "hello, write a 20 pg essay"}], timeout=0.0001, stream=True, ) for chunk in response: print(chunk) except openai.APITimeoutError as e: - print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print( + "Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e + ) print(type(e)) pass except Exception as e: - pytest.fail(f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}") -test_timeout_streaming() \ No newline at end of file + pytest.fail( + f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}" + ) + + +test_timeout_streaming() diff --git a/litellm/tests/test_together_ai.py b/litellm/tests/test_together_ai.py index 361ca8ee7..d4d5f968a 100644 --- a/litellm/tests/test_together_ai.py +++ b/litellm/tests/test_together_ai.py @@ -44,7 +44,7 @@ # model=model, stop=stop_words, max_tokens=512): # print(token, end="") - + # ### litellm # import os @@ -61,4 +61,4 @@ # res = completion(model="together_ai/togethercomputer/CodeLlama-13b-Instruct", # messages=sample_message, stream=False, max_tokens=1000) -# print(list(res)) \ No newline at end of file +# print(list(res)) diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py index b30e1126d..5903f46b5 100644 --- a/litellm/tests/test_token_counter.py +++ b/litellm/tests/test_token_counter.py @@ -4,6 +4,7 @@ import sys, os import traceback import pytest + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -11,30 +12,50 @@ import time from litellm import token_counter, encode, decode -def test_token_counter_normal_plus_function_calling(): - try: +def test_token_counter_normal_plus_function_calling(): + try: messages = [ - {'role': 'system', 'content': "System prompt"}, - {'role': 'user', 'content': 'content1'}, - {'role': 'assistant', 'content': 'content2'}, - {'role': 'user', 'content': 'conten3'}, - {'role': 'assistant', 'content': None, 'tool_calls': [{'id': 'call_E0lOb1h6qtmflUyok4L06TgY', 'function': {'arguments': '{"query":"search query","domain":"google.ca","gl":"ca","hl":"en"}', 'name': 'SearchInternet'}, 'type': 'function'}]}, - {'tool_call_id': 'call_E0lOb1h6qtmflUyok4L06TgY', 'role': 'tool', 'name': 'SearchInternet', 'content': 'tool content'} + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "content1"}, + {"role": "assistant", "content": "content2"}, + {"role": "user", "content": "conten3"}, + { + 
"role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_E0lOb1h6qtmflUyok4L06TgY", + "function": { + "arguments": '{"query":"search query","domain":"google.ca","gl":"ca","hl":"en"}', + "name": "SearchInternet", + }, + "type": "function", + } + ], + }, + { + "tool_call_id": "call_E0lOb1h6qtmflUyok4L06TgY", + "role": "tool", + "name": "SearchInternet", + "content": "tool content", + }, ] tokens = token_counter(model="gpt-3.5-turbo", messages=messages) print(f"tokens: {tokens}") - except Exception as e: + except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") -test_token_counter_normal_plus_function_calling() + +test_token_counter_normal_plus_function_calling() + def test_tokenizers(): - try: - ### test the openai, claude, cohere and llama2 tokenizers. + try: + ### test the openai, claude, cohere and llama2 tokenizers. ### The tokenizer value should be different for all sample_text = "Hellö World, this is my input string!" - # openai tokenizer + # openai tokenizer openai_tokens = token_counter(model="gpt-3.5-turbo", text=sample_text) # claude tokenizer @@ -43,36 +64,44 @@ def test_tokenizers(): # cohere tokenizer cohere_tokens = token_counter(model="command-nightly", text=sample_text) - # llama2 tokenizer - llama2_tokens = token_counter(model="meta-llama/Llama-2-7b-chat", text=sample_text) + # llama2 tokenizer + llama2_tokens = token_counter( + model="meta-llama/Llama-2-7b-chat", text=sample_text + ) - print(f"openai tokens: {openai_tokens}; claude tokens: {claude_tokens}; cohere tokens: {cohere_tokens}; llama2 tokens: {llama2_tokens}") + print( + f"openai tokens: {openai_tokens}; claude tokens: {claude_tokens}; cohere tokens: {cohere_tokens}; llama2 tokens: {llama2_tokens}" + ) # assert that all token values are different - assert openai_tokens != cohere_tokens != llama2_tokens, "Token values are not different." - + assert ( + openai_tokens != cohere_tokens != llama2_tokens + ), "Token values are not different." + print("test tokenizer: It worked!") - except Exception as e: - pytest.fail(f'An exception occured: {e}') + except Exception as e: + pytest.fail(f"An exception occured: {e}") + # test_tokenizers() -def test_encoding_and_decoding(): - try: + +def test_encoding_and_decoding(): + try: sample_text = "Hellö World, this is my input string!" 
# openai encoding + decoding openai_tokens = encode(model="gpt-3.5-turbo", text=sample_text) openai_text = decode(model="gpt-3.5-turbo", tokens=openai_tokens) assert openai_text == sample_text - - # claude encoding + decoding + + # claude encoding + decoding claude_tokens = encode(model="claude-instant-1", text=sample_text) claude_text = decode(model="claude-instant-1", tokens=claude_tokens.ids) assert claude_text == sample_text - # cohere encoding + decoding + # cohere encoding + decoding cohere_tokens = encode(model="command-nightly", text=sample_text) cohere_text = decode(model="command-nightly", tokens=cohere_tokens.ids) @@ -80,10 +109,13 @@ def test_encoding_and_decoding(): # llama2 encoding + decoding llama2_tokens = encode(model="meta-llama/Llama-2-7b-chat", text=sample_text) - llama2_text = decode(model="meta-llama/Llama-2-7b-chat", tokens=llama2_tokens.ids) + llama2_text = decode( + model="meta-llama/Llama-2-7b-chat", tokens=llama2_tokens.ids + ) assert llama2_text == sample_text - except Exception as e: - pytest.fail(f'An exception occured: {e}') + except Exception as e: + pytest.fail(f"An exception occured: {e}") -# test_encoding_and_decoding() \ No newline at end of file + +# test_encoding_and_decoding() diff --git a/litellm/tests/test_traceloop.py b/litellm/tests/test_traceloop.py index c03fdcd43..405a8a357 100644 --- a/litellm/tests/test_traceloop.py +++ b/litellm/tests/test_traceloop.py @@ -1,4 +1,4 @@ -# Commented out for now - since traceloop break ci/cd +# Commented out for now - since traceloop break ci/cd # import sys # import os # import io, asyncio @@ -15,8 +15,8 @@ # Traceloop.init(app_name="test-litellm", disable_batch=True) -# def test_traceloop_logging(): -# try: +# def test_traceloop_logging(): +# try: # litellm.set_verbose = True # response = litellm.completion( # model="gpt-3.5-turbo", @@ -26,13 +26,13 @@ # timeout=5, # ) # print(f"response: {response}") -# except Exception as e: +# except Exception as e: # pytest.fail(f"An exception occurred - {e}") # # test_traceloop_logging() -# # def test_traceloop_logging_async(): -# # try: +# # def test_traceloop_logging_async(): +# # try: # # litellm.set_verbose = True # # async def test_acompletion(): # # return await litellm.acompletion( @@ -44,6 +44,6 @@ # # ) # # response = asyncio.run(test_acompletion()) # # print(f"response: {response}") -# # except Exception as e: +# # except Exception as e: # # pytest.fail(f"An exception occurred - {e}") # # test_traceloop_logging_async() diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py index 394682fd1..9764d5e5a 100644 --- a/litellm/tests/test_utils.py +++ b/litellm/tests/test_utils.py @@ -10,126 +10,216 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm.utils import trim_messages, get_token_count, get_valid_models, check_valid_key, validate_environment, function_to_dict, token_counter +from litellm.utils import ( + trim_messages, + get_token_count, + get_valid_models, + check_valid_key, + validate_environment, + function_to_dict, + token_counter, +) # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils' + # Test 1: Check trimming of normal message def test_basic_trimming(): - messages = [{"role": "user", "content": "This is a long message that definitely exceeds the token limit."}] + messages = [ + { + "role": "user", + "content": "This is a long message that definitely exceeds the token limit.", + } + ] 
trimmed_messages = trim_messages(messages, model="claude-2", max_tokens=8) print("trimmed messages") print(trimmed_messages) # print(get_token_count(messages=trimmed_messages, model="claude-2")) assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8 + + # test_basic_trimming() + def test_basic_trimming_no_max_tokens_specified(): - messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}] + messages = [ + { + "role": "user", + "content": "This is a long message that is definitely under the token limit.", + } + ] trimmed_messages = trim_messages(messages, model="gpt-4") print("trimmed messages for gpt-4") print(trimmed_messages) # print(get_token_count(messages=trimmed_messages, model="claude-2")) - assert (get_token_count(messages=trimmed_messages, model="gpt-4")) <= litellm.model_cost['gpt-4']['max_tokens'] + assert ( + get_token_count(messages=trimmed_messages, model="gpt-4") + ) <= litellm.model_cost["gpt-4"]["max_tokens"] + + # test_basic_trimming_no_max_tokens_specified() + def test_multiple_messages_trimming(): messages = [ - {"role": "user", "content": "This is a long message that will exceed the token limit."}, - {"role": "user", "content": "This is another long message that will also exceed the limit."} + { + "role": "user", + "content": "This is a long message that will exceed the token limit.", + }, + { + "role": "user", + "content": "This is another long message that will also exceed the limit.", + }, ] - trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20) + trimmed_messages = trim_messages( + messages=messages, model="gpt-3.5-turbo", max_tokens=20 + ) # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) - assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20 + assert (get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20 + + # test_multiple_messages_trimming() + def test_multiple_messages_no_trimming(): messages = [ - {"role": "user", "content": "This is a long message that will exceed the token limit."}, - {"role": "user", "content": "This is another long message that will also exceed the limit."} + { + "role": "user", + "content": "This is a long message that will exceed the token limit.", + }, + { + "role": "user", + "content": "This is another long message that will also exceed the limit.", + }, ] - trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100) + trimmed_messages = trim_messages( + messages=messages, model="gpt-3.5-turbo", max_tokens=100 + ) print("Trimmed messages") print(trimmed_messages) - assert(messages==trimmed_messages) + assert messages == trimmed_messages + # test_multiple_messages_no_trimming() def test_large_trimming_multiple_messages(): - messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}] + messages = [ + {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, + {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, + {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, + {"role": "user", "content": "This is a 
singlelongwordthatexceedsthelimit."}, + {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, + ] trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613") print("trimmed messages") print(trimmed_messages) - assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20 + assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20 + + # test_large_trimming() + def test_large_trimming_single_message(): - messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}] + messages = [ + {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."} + ] trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613") - assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5 - assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0 + assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5 + assert (get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0 def test_trimming_with_system_message_within_max_tokens(): # This message is 33 tokens long - messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] - trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit + messages = [ + {"role": "system", "content": "This is a short system message"}, + { + "role": "user", + "content": "This is a medium normal message, let's say litellm is awesome.", + }, + ] + trimmed_messages = trim_messages( + messages, max_tokens=30, model="gpt-4-0613" + ) # The system message should fit within the token limit assert len(trimmed_messages) == 2 assert trimmed_messages[0]["content"] == "This is a short system message" def test_trimming_with_system_message_exceeding_max_tokens(): # This message is 33 tokens long. The system message is 13 tokens long. 
- messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] + messages = [ + {"role": "system", "content": "This is a short system message"}, + { + "role": "user", + "content": "This is a medium normal message, let's say litellm is awesome.", + }, + ] trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613") assert len(trimmed_messages) == 1 + def test_trimming_should_not_change_original_messages(): - messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}] + messages = [ + {"role": "system", "content": "This is a short system message"}, + { + "role": "user", + "content": "This is a medium normal message, let's say litellm is awesome.", + }, + ] messages_copy = copy.deepcopy(messages) trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613") - assert(messages==messages_copy) + assert messages == messages_copy + def test_get_valid_models(): old_environ = os.environ - os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ + os.environ = {"OPENAI_API_KEY": "temp"} # mock set only openai key in environ valid_models = get_valid_models() print(valid_models) # list of openai supported llms on litellm - expected_models = litellm.open_ai_chat_completion_models + litellm.open_ai_text_completion_models - - assert(valid_models == expected_models) + expected_models = ( + litellm.open_ai_chat_completion_models + litellm.open_ai_text_completion_models + ) + + assert valid_models == expected_models # reset replicate env key os.environ = old_environ + # test_get_valid_models() + def test_bad_key(): key = "bad-key" response = check_valid_key(model="gpt-3.5-turbo", api_key=key) print(response, key) - assert(response == False) + assert response == False + def test_good_key(): - key = os.environ['OPENAI_API_KEY'] + key = os.environ["OPENAI_API_KEY"] response = check_valid_key(model="gpt-3.5-turbo", api_key=key) - assert(response == True) + assert response == True + + +# test validate environment -# test validate environment def test_validate_environment_empty_model(): api_key = validate_environment() if api_key is None: - raise Exception() + raise Exception() + # test_validate_environment_empty_model() + def test_function_to_dict(): print("testing function to dict for get current weather") + def get_current_weather(location: str, unit: str): """Get the current weather in a given location @@ -147,90 +237,83 @@ def test_function_to_dict(): """ if location == "Boston, MA": return "The weather is 12F" + function_json = litellm.utils.function_to_dict(get_current_weather) print(function_json) expected_output = { - 'name': 'get_current_weather', - 'description': 'Get the current weather in a given location', - 'parameters': { - 'type': 'object', - 'properties': { - 'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, - 'unit': {'type': 'string', 'description': 'Temperature unit', 'enum': "['fahrenheit', 'celsius']"} - }, - 'required': ['location', 'unit'] - } + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": "['fahrenheit', 'celsius']", + }, + }, + "required": ["location", "unit"], + }, } print(expected_output) - - assert function_json['name'] == expected_output["name"] + + assert function_json["name"] == expected_output["name"] assert function_json["description"] == expected_output["description"] assert function_json["parameters"]["type"] == expected_output["parameters"]["type"] - assert function_json["parameters"]["properties"]["location"] == expected_output["parameters"]["properties"]["location"] + assert ( + function_json["parameters"]["properties"]["location"] + == expected_output["parameters"]["properties"]["location"] + ) # the enum can change it can be - which is why we don't assert on unit # {'type': 'string', 'description': 'Temperature unit', 'enum': "['fahrenheit', 'celsius']"} # {'type': 'string', 'description': 'Temperature unit', 'enum': "['celsius', 'fahrenheit']"} - assert function_json["parameters"]["required"] == expected_output["parameters"]["required"] + assert ( + function_json["parameters"]["required"] + == expected_output["parameters"]["required"] + ) print("passed") + + # test_function_to_dict() def test_token_counter(): try: - messages = [ - { - "role": "user", - "content": "hi how are you what time is it" - } - ] - tokens = token_counter( - model = "gpt-3.5-turbo", - messages=messages - ) + messages = [{"role": "user", "content": "hi how are you what time is it"}] + tokens = token_counter(model="gpt-3.5-turbo", messages=messages) print("gpt-35-turbo") print(tokens) assert tokens > 0 - tokens = token_counter( - model = "claude-2", - messages=messages - ) + tokens = token_counter(model="claude-2", messages=messages) print("claude-2") print(tokens) assert tokens > 0 - tokens = token_counter( - model = "palm/chat-bison", - messages=messages - ) + tokens = token_counter(model="palm/chat-bison", messages=messages) print("palm/chat-bison") print(tokens) assert tokens > 0 - tokens = token_counter( - model = "ollama/llama2", - messages=messages - ) + tokens = token_counter(model="ollama/llama2", messages=messages) print("ollama/llama2") print(tokens) assert tokens > 0 - tokens = token_counter( - model = "anthropic.claude-instant-v1", - messages=messages - ) + tokens = token_counter(model="anthropic.claude-instant-v1", messages=messages) print("anthropic.claude-instant-v1") print(tokens) assert tokens > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") + + test_token_counter() - - - - - diff --git a/litellm/tests/test_validate_environment.py b/litellm/tests/test_validate_environment.py index 2b60c7fe8..dce61b3ab 100644 --- a/litellm/tests/test_validate_environment.py +++ b/litellm/tests/test_validate_environment.py @@ -10,4 +10,4 @@ sys.path.insert( import time import litellm -print(litellm.validate_environment("openai/gpt-3.5-turbo")) \ No newline at end of file +print(litellm.validate_environment("openai/gpt-3.5-turbo")) diff --git a/litellm/tests/test_wandb.py b/litellm/tests/test_wandb.py index fe10b3e61..d31310fa6 100644 --- a/litellm/tests/test_wandb.py +++ b/litellm/tests/test_wandb.py @@ -1,58 +1,72 @@ import sys import os import io, asyncio + # import logging # logging.basicConfig(level=logging.DEBUG) -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion import litellm + litellm.num_retries = 3 litellm.success_callback = ["wandb"] import time import pytest -def 
test_wandb_logging_async(): - try: + +def test_wandb_logging_async(): + try: litellm.set_verbose = False + async def _test_langfuse(): from litellm import Router - model_list = [{ # list of model deployments - "model_name": "gpt-3.5-turbo", - "litellm_params": { # params for litellm completion/embedding call - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), + + model_list = [ + { # list of model deployments + "model_name": "gpt-3.5-turbo", + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, } - }] + ] router = Router(model_list=model_list) # openai.ChatCompletion.create replacement - response = await router.acompletion(model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "this is a test with litellm router ?"}]) + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + {"role": "user", "content": "this is a test with litellm router ?"} + ], + ) print(response) + response = asyncio.run(_test_langfuse()) print(f"response: {response}") - except litellm.Timeout as e: + except litellm.Timeout as e: pass - except Exception as e: + except Exception as e: pass + + test_wandb_logging_async() + def test_wandb_logging(): try: - response = completion(model="claude-instant-1.2", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm claude" - }], - max_tokens=10, - temperature=0.2 - ) + response = completion( + model="claude-instant-1.2", + messages=[{"role": "user", "content": "Hi 👋 - i'm claude"}], + max_tokens=10, + temperature=0.2, + ) print(response) - except litellm.Timeout as e: + except litellm.Timeout as e: pass except Exception as e: print(e) + # test_wandb_logging() diff --git a/litellm/timeout.py b/litellm/timeout.py index 9007c309e..b4446edf0 100644 --- a/litellm/timeout.py +++ b/litellm/timeout.py @@ -56,8 +56,8 @@ def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout): model = args[0] if len(args) > 0 else kwargs["model"] raise exception_to_raise( f"A timeout error occurred. The function call took longer than {local_timeout_duration} second(s).", - model=model, # [TODO]: replace with logic for parsing out llm provider from model name - llm_provider="openai" + model=model, # [TODO]: replace with logic for parsing out llm provider from model name + llm_provider="openai", ) thread.stop_loop() return result @@ -78,8 +78,8 @@ def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout): model = args[0] if len(args) > 0 else kwargs["model"] raise exception_to_raise( f"A timeout error occurred. 
The function call took longer than {local_timeout_duration} second(s).", - model=model, # [TODO]: replace with logic for parsing out llm provider from model name - llm_provider="openai" + model=model, # [TODO]: replace with logic for parsing out llm provider from model name + llm_provider="openai", ) if iscoroutinefunction(func): diff --git a/litellm/utils.py b/litellm/utils.py index 94fc46039..b6afdeb2d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -27,6 +27,7 @@ from dataclasses import ( dataclass, field, ) # for storing API inputs, outputs, and metadata + encoding = tiktoken.get_encoding("cl100k_base") import importlib.metadata from .integrations.traceloop import TraceloopLogger @@ -56,16 +57,17 @@ from .exceptions import ( APIConnectionError, APIError, BudgetExceededError, - UnprocessableEntityError + UnprocessableEntityError, ) from typing import cast, List, Dict, Union, Optional, Literal from .caching import Cache from concurrent.futures import ThreadPoolExecutor + ####### ENVIRONMENT VARIABLES #################### # Adjust to your specific application needs / system capabilities. -MAX_THREADS = 100 +MAX_THREADS = 100 -# Create a ThreadPoolExecutor +# Create a ThreadPoolExecutor executor = ThreadPoolExecutor(max_workers=MAX_THREADS) dotenv.load_dotenv() # Loading env variables using dotenv sentry_sdk_instance = None @@ -111,6 +113,7 @@ last_fetched_at_keys = None # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41} # } + class UnsupportedParamsError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -122,64 +125,81 @@ class UnsupportedParamsError(Exception): ) # Call the base class constructor with the parameters it needs -def _generate_id(): # private helper function - return 'chatcmpl-' + str(uuid.uuid4()) +def _generate_id(): # private helper function + return "chatcmpl-" + str(uuid.uuid4()) -def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null' + +def map_finish_reason( + finish_reason: str, +): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null' # anthropic mapping if finish_reason == "stop_sequence": return "stop" # cohere mapping - https://docs.cohere.com/reference/generate - elif finish_reason == "COMPLETE": + elif finish_reason == "COMPLETE": return "stop" - elif finish_reason == "MAX_TOKENS": # cohere + vertex ai + elif finish_reason == "MAX_TOKENS": # cohere + vertex ai return "length" - elif finish_reason == "ERROR_TOXIC": + elif finish_reason == "ERROR_TOXIC": return "content_filter" - elif finish_reason == "ERROR": # openai currently doesn't support an 'error' finish reason + elif ( + finish_reason == "ERROR" + ): # openai currently doesn't support an 'error' finish reason return "stop" # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream elif finish_reason == "eos_token" or finish_reason == "stop_sequence": return "stop" - elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',] + elif ( + finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP" + ): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 
'RECITATION', 'SAFETY', 'STOP',] return "stop" - elif finish_reason == "SAFETY": # vertex ai + elif finish_reason == "SAFETY": # vertex ai return "content_filter" return finish_reason + class FunctionCall(OpenAIObject): arguments: str name: str + class Function(OpenAIObject): arguments: str name: str + class ChatCompletionMessageToolCall(OpenAIObject): id: str function: Function type: str + class Message(OpenAIObject): - def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, tool_calls=None, **params): + def __init__( + self, + content="default", + role="assistant", + logprobs=None, + function_call=None, + tool_calls=None, + **params, + ): super(Message, self).__init__(**params) self.content = content self.role = role - if function_call is not None: + if function_call is not None: self.function_call = FunctionCall(**function_call) if tool_calls is not None: self.tool_calls = [] for tool_call in tool_calls: - self.tool_calls.append( - ChatCompletionMessageToolCall(**tool_call) - ) + self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call)) if logprobs is not None: - self._logprobs = logprobs + self._logprobs = logprobs def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -190,7 +210,7 @@ class Message(OpenAIObject): def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() @@ -201,7 +221,7 @@ class Delta(OpenAIObject): super(Delta, self).__init__(**params) self.content = content self.role = role - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -209,7 +229,7 @@ class Delta(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -222,13 +242,15 @@ class Delta(OpenAIObject): class Choices(OpenAIObject): def __init__(self, finish_reason=None, index=0, message=None, **params): super(Choices, self).__init__(**params) - self.finish_reason = map_finish_reason(finish_reason) or "stop" # set finish_reason for all responses + self.finish_reason = ( + map_finish_reason(finish_reason) or "stop" + ) # set finish_reason for all responses self.index = index if message is None: self.message = Message(content=None) else: self.message = message - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -236,7 +258,7 @@ class Choices(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -245,8 +267,11 @@ class Choices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class Usage(OpenAIObject): - def __init__(self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params): + def __init__( + self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params + ): super(Usage, self).__init__(**params) if prompt_tokens: 
self.prompt_tokens = prompt_tokens @@ -254,15 +279,15 @@ class Usage(OpenAIObject): self.completion_tokens = completion_tokens if total_tokens: self.total_tokens = total_tokens - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -271,8 +296,11 @@ class Usage(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class StreamingChoices(OpenAIObject): - def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]=None, **params): + def __init__( + self, finish_reason=None, index=0, delta: Optional[Delta] = None, **params + ): super(StreamingChoices, self).__init__(**params) if finish_reason: self.finish_reason = finish_reason @@ -283,15 +311,15 @@ class StreamingChoices(OpenAIObject): self.delta = delta else: self.delta = Delta() - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -300,7 +328,8 @@ class StreamingChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) -class ModelResponse(OpenAIObject): + +class ModelResponse(OpenAIObject): id: str """A unique identifier for the completion.""" @@ -328,7 +357,20 @@ class ModelResponse(OpenAIObject): _hidden_params: dict = {} - def __init__(self, id=None, choices=None, created=None, model=None, object=None, system_fingerprint=None, usage=None, stream=False, response_ms=None, hidden_params=None, **params): + def __init__( + self, + id=None, + choices=None, + created=None, + model=None, + object=None, + system_fingerprint=None, + usage=None, + stream=False, + response_ms=None, + hidden_params=None, + **params, + ): if stream: object = "chat.completion.chunk" choices = [StreamingChoices()] @@ -353,16 +395,25 @@ class ModelResponse(OpenAIObject): usage = Usage() if hidden_params: self._hidden_params = hidden_params - super().__init__(id=id, choices=choices, created=created, model=model, object=object, system_fingerprint=system_fingerprint, usage=usage, **params) - + super().__init__( + id=id, + choices=choices, + created=created, + model=model, + object=object, + system_fingerprint=system_fingerprint, + usage=usage, + **params, + ) + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -370,14 +421,15 @@ class ModelResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() + class Embedding(OpenAIObject): embedding: list = [] index: int @@ -386,7 +438,7 @@ class 
Embedding(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -395,6 +447,7 @@ class Embedding(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class EmbeddingResponse(OpenAIObject): model: Optional[str] = None """The model used for embedding.""" @@ -408,17 +461,19 @@ class EmbeddingResponse(OpenAIObject): usage: Optional[Usage] = None """Usage statistics for the embedding request.""" - def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None): + def __init__( + self, model=None, usage=None, stream=False, response_ms=None, data=None + ): object = "list" if response_ms: _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if usage: usage = usage else: @@ -430,11 +485,11 @@ class EmbeddingResponse(OpenAIObject): def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -442,14 +497,15 @@ class EmbeddingResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() + class TextChoices(OpenAIObject): def __init__(self, finish_reason=None, index=0, text=None, logprobs=None, **params): super(TextChoices, self).__init__(**params) @@ -466,7 +522,7 @@ class TextChoices(OpenAIObject): self.logprobs = [] else: self.logprobs = logprobs - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -474,7 +530,7 @@ class TextChoices(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -483,6 +539,7 @@ class TextChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class TextCompletionResponse(OpenAIObject): """ { @@ -501,7 +558,18 @@ class TextCompletionResponse(OpenAIObject): "usage": response["usage"] } """ - def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params): + + def __init__( + self, + id=None, + choices=None, + created=None, + model=None, + usage=None, + stream=False, + response_ms=None, + **params, + ): super(TextCompletionResponse, self).__init__(**params) if stream: self.object = "text_completion.chunk" @@ -526,9 +594,10 @@ class TextCompletionResponse(OpenAIObject): self.usage = usage else: self.usage = Usage() - self._hidden_params = {} # used in case users want to access the original model response + self._hidden_params = ( + {} + ) # used in case users want to access the original model response - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ 
-536,7 +605,7 @@ class TextCompletionResponse(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -548,7 +617,7 @@ class TextCompletionResponse(OpenAIObject): class ImageResponse(OpenAIObject): created: Optional[int] = None - + data: Optional[list] = None usage: Optional[dict] = None @@ -558,28 +627,27 @@ class ImageResponse(OpenAIObject): _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if created: created = created else: created = None - + super().__init__(data=data, created=created) self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -587,52 +655,69 @@ class ImageResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() - + + ############################################################ def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass + ####### LOGGING ################### from enum import Enum + class CallTypes(Enum): - embedding = 'embedding' - completion = 'completion' - acompletion = 'acompletion' - aembedding = 'aembedding' - image_generation = 'image_generation' - aimage_generation = 'aimage_generation' + embedding = "embedding" + completion = "completion" + acompletion = "acompletion" + aembedding = "aembedding" + image_generation = "image_generation" + aimage_generation = "aimage_generation" + # Logging function -> log the exact model details + what's being sent | Non-Blocking class Logging: global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, capture_exception, add_breadcrumb, llmonitorLogger - def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id): + def __init__( + self, + model, + messages, + stream, + call_type, + start_time, + litellm_call_id, + function_id, + ): if call_type not in [item.value for item in CallTypes]: allowed_values = ", ".join([item.value for item in CallTypes]) - raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}") + raise ValueError( + f"Invalid call_type {call_type}. 
Allowed values: {allowed_values}" + ) self.model = model self.messages = messages self.stream = stream - self.start_time = start_time # log the call start time + self.start_time = start_time # log the call start time self.call_type = call_type self.litellm_call_id = litellm_call_id self.function_id = function_id - self.streaming_chunks = [] # for generating complete stream response + self.streaming_chunks = [] # for generating complete stream response self.model_call_details = {} - - def update_environment_variables(self, model, user, optional_params, litellm_params, **additional_params): + + def update_environment_variables( + self, model, user, optional_params, litellm_params, **additional_params + ): self.optional_params = optional_params self.model = model self.user = user @@ -649,10 +734,10 @@ class Logging: "user": user, "call_type": str(self.call_type), **self.optional_params, - **additional_params + **additional_params, } - def _pre_call(self, input, api_key, model=None, additional_args={}): + def _pre_call(self, input, api_key, model=None, additional_args={}): """ Common helper function across the sync + async pre-call function """ @@ -662,31 +747,43 @@ class Logging: self.model_call_details["additional_args"] = additional_args self.model_call_details["log_event_type"] = "pre_api_call" if ( - model - ): # if model name was changes pre-call, overwrite the initial model call name with the new one - self.model_call_details["model"] = model + model + ): # if model name was changes pre-call, overwrite the initial model call name with the new one + self.model_call_details["model"] = model def pre_call(self, input, api_key, model=None, additional_args={}): # Log the exact input to the LLM API - litellm.error_logs['PRE_CALL'] = locals() + litellm.error_logs["PRE_CALL"] = locals() try: - self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args) + self._pre_call( + input=input, + api_key=api_key, + model=model, + additional_args=additional_args, + ) # User Logging -> if you pass in a custom logging function headers = additional_args.get("headers", {}) - if headers is None: + if headers is None: headers = {} data = additional_args.get("complete_input_dict", {}) api_base = additional_args.get("api_base", "") - masked_headers = {k: (v[:-20] + '*' * 20) if (isinstance(v, str) and len(v) > 20) else v for k, v in headers.items()} - formatted_headers = " ".join([f"-H '{k}: {v}'" for k, v in masked_headers.items()]) + masked_headers = { + k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v + for k, v in headers.items() + } + formatted_headers = " ".join( + [f"-H '{k}: {v}'" for k, v in masked_headers.items()] + ) print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command += "curl -X POST \\\n" curl_command += f"{api_base} \\\n" - curl_command += f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" + curl_command += ( + f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" + ) curl_command += f"-d '{str(data)}'\n" if additional_args.get("request_str", None) is not None: # print the sagemaker / bedrock client request @@ -707,10 +804,17 @@ class Logging: if litellm.max_budget and self.stream: start_time = self.start_time - end_time = self.start_time # no time has passed as the call hasn't been made yet + end_time = ( + self.start_time + ) # no time has passed as the call hasn't been made yet time_diff = (end_time - start_time).total_seconds() 
float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost(model=self.model, prompt="".join(message["content"] for message in self.messages), completion="", total_time=float_diff) + litellm._current_cost += litellm.completion_cost( + model=self.model, + prompt="".join(message["content"] for message in self.messages), + completion="", + total_time=float_diff, + ) # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: @@ -729,7 +833,9 @@ class Logging: ) elif callback == "lite_debugger": - print_verbose(f"reaches litedebugger for logging! - model_call_details {self.model_call_details}") + print_verbose( + f"reaches litedebugger for logging! - model_call_details {self.model_call_details}" + ) model = self.model_call_details["model"] messages = self.model_call_details["input"] print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") @@ -741,7 +847,7 @@ class Logging: litellm_params=self.model_call_details["litellm_params"], optional_params=self.model_call_details["optional_params"], print_verbose=print_verbose, - call_type=self.call_type + call_type=self.call_type, ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -750,19 +856,19 @@ class Logging: message=f"Model Call Details pre-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_pre_api_call( model=self.model, messages=self.messages, kwargs=self.model_call_details, ) - elif callable(callback): # custom logger functions + elif callable(callback): # custom logger functions customLogger.log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) except Exception as e: traceback.print_exc() @@ -784,37 +890,48 @@ class Logging: if capture_exception: # log this error to sentry for debugging capture_exception(e) - async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs): + async def async_pre_call( + self, result=None, start_time=None, end_time=None, **kwargs + ): """ - Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, end_time=end_time, result=result + ) print_verbose(f"Async input callbacks: {litellm._async_input_callback}") for callback in litellm._async_input_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async input callbacks: CustomLogger") - asyncio.create_task(callback.async_log_input_event( + asyncio.create_task( + callback.async_log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, - )) - if callable(callback): # custom logger functions + ) + ) + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") - asyncio.create_task(customLogger.async_log_input_event( - model=self.model, - messages=self.messages, - kwargs=self.model_call_details, - print_verbose=print_verbose, - callback_func=callback - )) - except: + asyncio.create_task( + customLogger.async_log_input_event( + model=self.model, + messages=self.messages, + kwargs=self.model_call_details, + print_verbose=print_verbose, + callback_func=callback, + ) + ) + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) - def post_call(self, original_response, input=None, api_key=None, additional_args={}): + + def post_call( + self, original_response, input=None, api_key=None, additional_args={} + ): # Log the exact result from the LLM API, for streaming - log the type of response received - litellm.error_logs['POST_CALL'] = locals() + litellm.error_logs["POST_CALL"] = locals() try: self.model_call_details["input"] = input self.model_call_details["api_key"] = api_key @@ -823,11 +940,15 @@ class Logging: self.model_call_details["log_event_type"] = "post_api_call" # User Logging -> if you pass in a custom logging function - print_verbose(f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n") + print_verbose( + f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n" + ) print_verbose( f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" ) - print_verbose(f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}") + print_verbose( + f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}" + ) if self.logger_fn and callable(self.logger_fn): try: self.logger_fn( @@ -837,7 +958,7 @@ class Logging: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: try: @@ -848,8 +969,8 @@ class Logging: original_response=original_response, litellm_call_id=self.litellm_params["litellm_call_id"], print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, + call_type=self.call_type, + stream=self.stream, ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -858,12 +979,12 @@ class Logging: message=f"Model Call Details post-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger 
class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_post_api_call( kwargs=self.model_call_details, response_obj=None, start_time=self.start_time, - end_time=None + end_time=None, ) except Exception as e: print_verbose( @@ -879,9 +1000,11 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) pass - - def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None): - try: + + def _success_handler_helper_fn( + self, result=None, start_time=None, end_time=None, cache_hit=None + ): + try: if start_time is None: start_time = self.start_time if end_time is None: @@ -893,42 +1016,67 @@ class Logging: if litellm.max_budget and self.stream: time_diff = (end_time - start_time).total_seconds() float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff) + litellm._current_cost += litellm.completion_cost( + model=self.model, + prompt="", + completion=result["content"], + total_time=float_diff, + ) return start_time, end_time, result - except Exception as e: + except Exception as e: print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): - print_verbose( - f"Logging Details LiteLLM-Success Call" - ) + def success_handler( + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): + print_verbose(f"Logging Details LiteLLM-Success Call") # print(f"original response in success handler: {self.model_call_details['original_response']}") try: - print_verbose(f"success callbacks: {litellm.success_callback}") + print_verbose(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # only call stream chunk builder if it's not acompletion() - if result.choices[0].finish_reason is not None: # if it's the last chunk + if ( + self.stream + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + ): # only call stream chunk builder if it's not acompletion() + if ( + result.choices[0].finish_reason is not None + ): # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) - except: + try: + complete_streaming_response = litellm.stream_chunk_builder( + self.streaming_chunks, + messages=self.model_call_details.get("messages", None), + ) + except: complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: - self.model_call_details["complete_streaming_response"] = complete_streaming_response + if complete_streaming_response: + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, + end_time=end_time, + result=result, + cache_hit=cache_hit, + ) for callback in litellm.success_callback: try: if callback == 
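litellm.stream_chunk_builder, which the success handler above uses to assemble a complete response once the final chunk arrives, can also be called directly on chunks collected from a streaming call. A minimal sketch, assuming OpenAI credentials are configured; the prompt is illustrative:

import litellm

messages = [{"role": "user", "content": "Say hi"}]
chunks = []
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    chunks.append(chunk)

# same helper the success handler calls with self.streaming_chunks
full_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(full_response["choices"][0]["message"]["content"])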
"lite_debugger": print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - print_verbose(f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}") + print_verbose( + f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}" + ) liteDebuggerClient.log_event( end_user=kwargs.get("user", "default"), response_obj=result, @@ -936,8 +1084,8 @@ class Logging: end_time=end_time, litellm_call_id=self.litellm_call_id, print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, + call_type=self.call_type, + stream=self.stream, ) if callback == "promptlayer": print_verbose("reaches promptlayer for logging!") @@ -950,8 +1098,8 @@ class Logging: ) if callback == "supabase": print_verbose("reaches supabase for logging!") - kwargs=self.model_call_details - + kwargs = self.model_call_details + # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: if "complete_streaming_response" not in kwargs: @@ -959,7 +1107,7 @@ class Logging: else: print_verbose("reaches supabase for streaming logging!") result = kwargs["complete_streaming_response"] - + model = kwargs["model"] messages = kwargs["messages"] optional_params = kwargs.get("optional_params", {}) @@ -971,7 +1119,9 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - litellm_call_id=litellm_params.get("litellm_call_id", str(uuid.uuid4())), + litellm_call_id=litellm_params.get( + "litellm_call_id", str(uuid.uuid4()) + ), print_verbose=print_verbose, ) if callback == "wandb": @@ -1002,10 +1152,16 @@ class Logging: print_verbose("reaches llmonitor for logging!") model = self.model - input = self.model_call_details.get("messages", self.model_call_details.get("input", None)) + input = self.model_call_details.get( + "messages", self.model_call_details.get("input", None) + ) # if contains input, it's 'embedding', otherwise 'llm' - type = "embed" if self.call_type == CallTypes.embedding.value else "llm" + type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) llmonitorLogger.log_event( type=type, @@ -1035,8 +1191,10 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1061,17 +1219,21 @@ class Logging: kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose(f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") + print_verbose( + f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" + ) return else: - print_verbose("success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache") + print_verbose( + "success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache" + ) result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) if callback == "traceloop": deep_copy = {} - for k, v in self.model_call_details.items(): - if k != "original_response": + for k, v in self.model_call_details.items(): + if k != "original_response": deep_copy[k] = v traceloopLogger.log_event( kwargs=deep_copy, @@ -1080,18 +1242,32 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class - print_verbose(f"success callbacks: Running Custom Logger Class") + elif ( + isinstance(callback, CustomLogger) + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "aembedding", False + ) + == False + ): # custom logger class + print_verbose(f"success callbacks: Running Custom Logger Class") if self.stream and complete_streaming_response is None: callback.log_stream_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time - ) + end_time=end_time, + ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = self.model_call_details.get("complete_streaming_response", {}) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} + ) result = self.model_call_details["complete_response"] callback.log_success_event( kwargs=self.model_call_details, @@ -1099,15 +1275,17 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions - print_verbose(f"success callbacks: Running Custom Callback Function") + if callable(callback): # custom logger functions + print_verbose( + f"success callbacks: Running Custom Callback Function" + ) customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) except Exception as e: @@ -1125,60 +1303,77 @@ class Logging: ) pass - async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): + async def async_success_handler( + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" print_verbose(f"Async success callbacks: {litellm._async_success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream: - if result.choices[0].finish_reason is not None: # if it's the last chunk + if self.stream: + if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) + try: + complete_streaming_response = litellm.stream_chunk_builder( + self.streaming_chunks, + messages=self.model_call_details.get("messages", None), + ) except Exception as e: - print_verbose(f"Error occurred building stream chunk: {traceback.format_exc()}") + print_verbose( + f"Error occurred building stream chunk: {traceback.format_exc()}" + ) complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: + if complete_streaming_response: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["complete_streaming_response"] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit + ) for callback in litellm._async_success_callback: - try: + try: if callback == "cache" and litellm.cache is not None: # set_cache once complete streaming response is built print_verbose("async success_callback: reaches cache for logging!") kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") + print_verbose( + f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" + ) return else: - print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache") + print_verbose( + "async success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache" + ) result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) - if isinstance(callback, CustomLogger): # custom logger class + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async success callbacks: CustomLogger") if self.stream: if "complete_streaming_response" in self.model_call_details: await callback.async_log_success_event( kwargs=self.model_call_details, - response_obj=self.model_call_details["complete_streaming_response"], + response_obj=self.model_call_details[ + "complete_streaming_response" + ], start_time=start_time, end_time=end_time, ) - else: - await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function + else: + await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time - ) + end_time=end_time, + ) else: await callback.async_log_success_event( kwargs=self.model_call_details, @@ -1186,7 +1381,7 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") await customLogger.async_log_event( kwargs=self.model_call_details, @@ -1194,7 +1389,7 @@ class Logging: start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) if callback == "dynamodb": global dynamoLogger @@ -1202,16 +1397,22 @@ class Logging: dynamoLogger = DyanmoDBLogger() if self.stream: if "complete_streaming_response" in self.model_call_details: - print_verbose("DynamoDB Logger: Got Stream Event - Completed Stream Response") + print_verbose( + "DynamoDB Logger: Got Stream Event - Completed Stream Response" + ) await dynamoLogger._async_log_event( kwargs=self.model_call_details, - response_obj=self.model_call_details["complete_streaming_response"], + response_obj=self.model_call_details[ + "complete_streaming_response" + ], start_time=start_time, end_time=end_time, - print_verbose=print_verbose + print_verbose=print_verbose, + ) + else: + print_verbose( + "DynamoDB Logger: Got Stream Event - No complete stream response as yet" ) - else: - print_verbose("DynamoDB Logger: Got Stream Event - No complete stream response as yet") else: await dynamoLogger._async_log_event( kwargs=self.model_call_details, @@ -1224,8 +1425,10 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1243,13 +1446,15 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - except: + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) pass - def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None): + def _failure_handler_helper_fn( + self, exception, traceback_exception, start_time=None, end_time=None + ): if start_time is None: 
start_time = self.start_time if end_time is None: @@ -1266,49 +1471,58 @@ class Logging: self.model_call_details.setdefault("original_response", None) return start_time, end_time - def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): - print_verbose( - f"Logging Details LiteLLM-Failure Call" - ) + def failure_handler( + self, exception, traceback_exception, start_time=None, end_time=None + ): + print_verbose(f"Logging Details LiteLLM-Failure Call") try: - start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) - result = None # result sent to all loggers, init this to None incase it's not created + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm.failure_callback: try: if callback == "lite_debugger": - print_verbose("reaches lite_debugger for logging!") - print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - result = { - "model": self.model, - "created": time.time(), - "error": traceback_exception, - "usage": { - "prompt_tokens": prompt_token_calculator( - self.model, messages=self.messages - ), - "completion_tokens": 0, - }, - } - liteDebuggerClient.log_event( - model=self.model, - messages=self.messages, - end_user=self.model_call_details.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=self.litellm_call_id, - print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, - ) + print_verbose("reaches lite_debugger for logging!") + print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") + result = { + "model": self.model, + "created": time.time(), + "error": traceback_exception, + "usage": { + "prompt_tokens": prompt_token_calculator( + self.model, messages=self.messages + ), + "completion_tokens": 0, + }, + } + liteDebuggerClient.log_event( + model=self.model, + messages=self.messages, + end_user=self.model_call_details.get("user", "default"), + response_obj=result, + start_time=start_time, + end_time=end_time, + litellm_call_id=self.litellm_call_id, + print_verbose=print_verbose, + call_type=self.call_type, + stream=self.stream, + ) elif callback == "llmonitor": print_verbose("reaches llmonitor for logging error!") model = self.model input = self.model_call_details["input"] - - type = "embed" if self.call_type == CallTypes.embedding.value else "llm" + + type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) llmonitorLogger.log_event( type=type, @@ -1327,17 +1541,29 @@ class Logging: if capture_exception: capture_exception(exception) else: - print_verbose(f"capture exception not initialized: {capture_exception}") - elif callable(callback): # custom logger functions + print_verbose( + f"capture exception not initialized: {capture_exception}" + ) + elif callable(callback): # custom logger functions customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) - elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) 
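A sketch of a plain failure callback, assuming it is invoked through the callable branch above with (kwargs, response_obj, start_time, end_time) and response_obj left as None; the function name is hypothetical:

import litellm

def log_failure(kwargs, response_obj, start_time, end_time):
    # response_obj arrives as None on failures (result is initialized to None above)
    duration = (end_time - start_time).total_seconds()
    print(f"call to {kwargs.get('model')} failed after {duration:.2f}s")

litellm.failure_callback = [log_failure]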
== False: # custom logger class + elif ( + isinstance(callback, CustomLogger) + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "aembedding", False + ) + == False + ): # custom logger class callback.log_failure_event( start_time=start_time, end_time=end_time, @@ -1359,37 +1585,43 @@ class Logging: ) pass - async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): + async def async_failure_handler( + self, exception, traceback_exception, start_time=None, end_time=None + ): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ - start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) - result = None # result sent to all loggers, init this to None incase it's not created + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm._async_failure_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class await callback.async_log_failure_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions await customLogger.async_log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback - ) - except Exception as e: + callback_func=callback, + ) + except Exception as e: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) - def exception_logging( additional_args={}, logger_fn=None, @@ -1422,53 +1654,59 @@ def exception_logging( ####### RULES ################### -class Rules: + +class Rules: """ Fail calls based on the input or llm api output - Example usage: - import litellm - def my_custom_rule(input): # receives the model response - if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer - return False - return True - + Example usage: + import litellm + def my_custom_rule(input): # receives the model response + if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer + return False + return True + litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call - response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", - "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) + response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", + "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) """ + def __init__(self) -> None: pass - def pre_call_rules(self, input: str, model: str): - for rule in litellm.pre_call_rules: - if callable(rule): + def pre_call_rules(self, input: str, model: str): + for rule in litellm.pre_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed 
post-call-rule check", llm_provider="", model=model) # type: ignore - return True + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True - def post_call_rules(self, input: str, model: str): - for rule in litellm.post_call_rules: - if callable(rule): + def post_call_rules(self, input: str, model: str): + for rule in litellm.post_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore - return True + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True + ####### CLIENT ################### # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking def client(original_function): global liteDebuggerClient, get_all_keys rules_obj = Rules() + def function_setup( start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. try: global callback_list, add_breadcrumb, user_logger_fn, Logging function_id = kwargs["id"] if "id" in kwargs else None - if litellm.use_client or ("use_client" in kwargs and kwargs["use_client"] == True): + if litellm.use_client or ( + "use_client" in kwargs and kwargs["use_client"] == True + ): print_verbose(f"litedebugger initialized") if "lite_debugger" not in litellm.input_callback: litellm.input_callback.append("lite_debugger") @@ -1476,8 +1714,8 @@ def client(original_function): litellm.success_callback.append("lite_debugger") if "lite_debugger" not in litellm.failure_callback: litellm.failure_callback.append("lite_debugger") - if len(litellm.callbacks) > 0: - for callback in litellm.callbacks: + if len(litellm.callbacks) > 0: + for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) if callback not in litellm.success_callback: @@ -1488,7 +1726,9 @@ def client(original_function): litellm._async_success_callback.append(callback) if callback not in litellm._async_failure_callback: litellm._async_failure_callback.append(callback) - print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}") + print_verbose( + f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" + ) if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 @@ -1501,10 +1741,7 @@ def client(original_function): + litellm.failure_callback ) ) - set_callbacks( - callback_list=callback_list, - function_id=function_id - ) + set_callbacks(callback_list=callback_list, function_id=function_id) ## ASYNC CALLBACKS if len(litellm.input_callback) > 0: removed_async_items = [] @@ -1517,10 +1754,10 @@ def client(original_function): for index in reversed(removed_async_items): litellm.input_callback.pop(index) - if len(litellm.success_callback) > 0: + if len(litellm.success_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.success_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.success_callback): + if inspect.iscoroutinefunction(callback): litellm._async_success_callback.append(callback) removed_async_items.append(index) elif callback == "dynamodb": @@ -1528,7 +1765,9 @@ def client(original_function): # we only 
support async dynamo db logging for acompletion/aembedding since that's used on proxy litellm._async_success_callback.append(callback) removed_async_items.append(index) - elif callback == "langfuse" and inspect.iscoroutinefunction(original_function): + elif callback == "langfuse" and inspect.iscoroutinefunction( + original_function + ): # use async success callback for langfuse if this is litellm.acompletion(). Streaming logging does not work otherwise litellm._async_success_callback.append(callback) removed_async_items.append(index) @@ -1536,11 +1775,11 @@ def client(original_function): # Pop the async items from success_callback in reverse order to avoid index issues for index in reversed(removed_async_items): litellm.success_callback.pop(index) - - if len(litellm.failure_callback) > 0: + + if len(litellm.failure_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.failure_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.failure_callback): + if inspect.iscoroutinefunction(callback): litellm._async_failure_callback.append(callback) removed_async_items.append(index) @@ -1560,38 +1799,75 @@ def client(original_function): # INIT LOGGER - for user-specified integrations model = args[0] if len(args) > 0 else kwargs.get("model", None) call_type = original_function.__name__ - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + ): messages = None if len(args) > 1: - messages = args[1] + messages = args[1] elif kwargs.get("messages", None): messages = kwargs["messages"] - ### PRE-CALL RULES ### - if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) and "content" in messages[0]: - rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model) - elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value: + ### PRE-CALL RULES ### + if ( + isinstance(messages, list) + and len(messages) > 0 + and isinstance(messages[0], dict) + and "content" in messages[0] + ): + rules_obj.pre_call_rules( + input="".join( + m["content"] + for m in messages + if isinstance(m["content"], str) + ), + model=model, + ) + elif ( + call_type == CallTypes.embedding.value + or call_type == CallTypes.aembedding.value + ): messages = args[1] if len(args) > 1 else kwargs["input"] - elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value: + elif ( + call_type == CallTypes.image_generation.value + or call_type == CallTypes.aimage_generation.value + ): messages = args[0] if len(args) > 0 else kwargs["prompt"] stream = True if "stream" in kwargs and kwargs["stream"] == True else False - logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time) + logging_obj = Logging( + model=model, + messages=messages, + stream=stream, + litellm_call_id=kwargs["litellm_call_id"], + function_id=function_id, + call_type=call_type, + start_time=start_time, + ) return logging_obj - except Exception as e: + except Exception as e: import logging - logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}") + + logging.debug( + f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}" + ) raise e - + def 
post_call_processing(original_response, model): - try: - if original_response is None: + try: + if original_response is None: pass - else: + else: call_type = original_function.__name__ - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: - model_response = original_response['choices'][0]['message']['content'] - ### POST-CALL RULES ### + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + ): + model_response = original_response["choices"][0]["message"][ + "content" + ] + ### POST-CALL RULES ### rules_obj.post_call_rules(input=model_response, model=model) - except Exception as e: + except Exception as e: raise e def crash_reporting(*args, **kwargs): @@ -1625,7 +1901,7 @@ def client(original_function): model = args[0] if len(args) > 0 else kwargs["model"] except: model = None - call_type = original_function.__name__ + call_type = original_function.__name__ if call_type != CallTypes.image_generation.value: raise ValueError("model param not passed in.") @@ -1635,89 +1911,135 @@ def client(original_function): kwargs["litellm_logging_obj"] = logging_obj # CHECK FOR 'os.environ/' in kwargs - for k,v in kwargs.items(): + for k, v in kwargs.items(): if v is not None and isinstance(v, str) and v.startswith("os.environ/"): kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) # [OPTIONAL] CHECK CACHE - print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") - # if caching is false, don't run this - if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function + print_verbose( + f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" + ) + # if caching is false, don't run this + if ( + kwargs.get("caching", None) is None and litellm.cache is not None + ) or kwargs.get( + "caching", False + ) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): print_verbose(f"Checking Cache") preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key + kwargs[ + "preset_cache_key" + ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result != None: - if "detail" in cached_result: - # implies an error occurred + if "detail" in cached_result: + # implies an error occurred pass - else: + else: call_type = original_function.__name__ - print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}") - if call_type == CallTypes.completion.value and isinstance(cached_result, dict): - return convert_to_model_response_object(response_object=cached_result, 
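A sketch of the budget check above in use, assuming BudgetExceededError is exposed on the top-level litellm package; the budget value is deliberately tiny so one of the later calls trips it:

import litellm

litellm.max_budget = 0.0001  # USD, accumulated across calls in litellm._current_cost

messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
    for _ in range(10):
        litellm.completion(model="gpt-3.5-turbo", messages=messages)
except litellm.BudgetExceededError as e:
    # raised by the wrapper before a call once accumulated cost exceeds the budget
    print(e)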
model_response_object=ModelResponse(), stream = kwargs.get("stream", False)) - elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict): - return convert_to_model_response_object(response_object=cached_result, response_type="embedding") - else: + print_verbose( + f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" + ) + if call_type == CallTypes.completion.value and isinstance( + cached_result, dict + ): + return convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + stream=kwargs.get("stream", False), + ) + elif call_type == CallTypes.embedding.value and isinstance( + cached_result, dict + ): + return convert_to_model_response_object( + response_object=cached_result, + response_type="embedding", + ) + else: return cached_result # MODEL CALL result = original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: # TODO: Add to cache for streaming - if "complete_response" in kwargs and kwargs["complete_response"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) - else: + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: return result - elif "acompletion" in kwargs and kwargs["acompletion"] == True: + elif "acompletion" in kwargs and kwargs["acompletion"] == True: return result - elif "aembedding" in kwargs and kwargs["aembedding"] == True: + elif "aembedding" in kwargs and kwargs["aembedding"] == True: return result elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True: return result - - ### POST-CALL RULES ### + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model or None) # [OPTIONAL] ADD TO CACHE - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): litellm.cache.add_cache(result, *args, **kwargs) # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated print_verbose(f"Wrapper: Completed Call, calling success_handler") - threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() + threading.Thread( + target=logging_obj.success_handler, args=(result, start_time, end_time) + ).start() # RETURN RESULT - result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai return result except Exception as e: call_type = original_function.__name__ if call_type == CallTypes.completion.value: num_retries = ( - kwargs.get("num_retries", None) - or litellm.num_retries - or None + kwargs.get("num_retries", None) or litellm.num_retries or None + ) + litellm.num_retries = ( + None # set retries to None to prevent infinite loops + ) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", {} ) - litellm.num_retries = None # set retries to None to prevent infinite loops - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) - if num_retries: - if 
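A sketch of the cache path above, assuming litellm.caching.Cache with its default local in-memory backend; the per-call caching=False opt-out mirrors the kwargs check in this hunk:

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # completion/embedding are in cache.supported_call_types

messages = [{"role": "user", "content": "What is 2 + 2?"}]
first = litellm.completion(model="gpt-3.5-turbo", messages=messages)
second = litellm.completion(model="gpt-3.5-turbo", messages=messages)  # served from cache by the wrapper
third = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=False)  # bypasses the cache check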
(isinstance(e, openai.APIError) - or isinstance(e, openai.Timeout)): + if num_retries: + if isinstance(e, openai.APIError) or isinstance(e, openai.Timeout): kwargs["num_retries"] = num_retries return litellm.completion_with_retries(*args, **kwargs) - elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: + elif ( + isinstance(e, litellm.exceptions.ContextWindowExceededError) + and context_window_fallback_dict + and model in context_window_fallback_dict + ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return original_function(*args, **kwargs) @@ -1726,7 +2048,9 @@ def client(original_function): end_time = datetime.datetime.now() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated if logging_obj: - logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! + logging_obj.failure_handler( + e, traceback_exception, start_time, end_time + ) # DO NOT MAKE THREADED - router retry fallback relies on this! my_thread = threading.Thread( target=handle_failure, args=(e, traceback_exception, start_time, end_time, args, kwargs), @@ -1738,8 +2062,8 @@ def client(original_function): ): # make it easy to get to the debugger logs if you've initialized it e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e - - async def wrapper_async(*args, **kwargs): + + async def wrapper_async(*args, **kwargs): start_time = datetime.datetime.now() result = None logging_obj = kwargs.get("litellm_logging_obj", None) @@ -1750,115 +2074,215 @@ def client(original_function): model = args[0] if len(args) > 0 else kwargs["model"] except: raise ValueError("model param not passed in.") - - try: + + try: if logging_obj is None: logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) # [OPTIONAL] CHECK CACHE print_verbose(f"litellm.cache: {litellm.cache}") - print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") - # if caching is false, don't run this - if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function + print_verbose( + f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" + ) + # if caching is false, don't run this + if ( + kwargs.get("caching", None) is None and litellm.cache is not None + ) or kwargs.get( + "caching", False + ) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): print_verbose(f"Checking Cache") cached_result = 
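A sketch of the retry and context-window-fallback kwargs handled above; the fallback mapping is illustrative:

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    num_retries=2,  # retried via completion_with_retries on openai.APIError / openai.Timeout
    context_window_fallback_dict={"gpt-3.5-turbo": "gpt-3.5-turbo-16k"},  # swapped in on ContextWindowExceededError
)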
litellm.cache.get_cache(*args, **kwargs) if cached_result != None: print_verbose(f"Cache Hit!") call_type = original_function.__name__ - if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict): + if call_type == CallTypes.acompletion.value and isinstance( + cached_result, dict + ): if kwargs.get("stream", False) == True: cached_result = convert_to_streaming_response_async( response_object=cached_result, ) else: - cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse()) - elif call_type == CallTypes.aembedding.value and isinstance(cached_result, dict): - cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding") - # LOG SUCCESS + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + ) + elif call_type == CallTypes.aembedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=EmbeddingResponse(), + response_type="embedding", + ) + # LOG SUCCESS cache_hit = True end_time = datetime.datetime.now() - model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(model=model, custom_llm_provider=kwargs.get('custom_llm_provider', None), api_base=kwargs.get('api_base', None), api_key=kwargs.get('api_key', None)) - print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") - logging_obj.update_environment_variables(model=model, user=kwargs.get('user', None), optional_params={}, litellm_params={"logger_fn": kwargs.get('logger_fn', None), "acompletion": True, "metadata": kwargs.get("metadata", {}), "model_info": kwargs.get("model_info", {}), "proxy_server_request": kwargs.get("proxy_server_request", None), "preset_cache_key": kwargs.get("preset_cache_key", None), "stream_response": kwargs.get("stream_response", {})}, input=kwargs.get('messages', ""), api_key=kwargs.get('api_key', None), original_response=str(cached_result), additional_args=None, stream=kwargs.get('stream', False)) - asyncio.create_task(logging_obj.async_success_handler(cached_result, start_time, end_time, cache_hit)) - threading.Thread(target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit)).start() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get("stream_response", {}), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + asyncio.create_task( + 
logging_obj.async_success_handler( + cached_result, start_time, end_time, cache_hit + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() return cached_result # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: - if "complete_response" in kwargs and kwargs["complete_response"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) - else: + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: return result - - ### POST-CALL RULES ### + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model) # [OPTIONAL] ADD TO CACHE - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: - if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse): - asyncio.create_task(litellm.cache._async_add_cache(result.json(), *args, **kwargs)) + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): + if isinstance(result, litellm.ModelResponse) or isinstance( + result, litellm.EmbeddingResponse + ): + asyncio.create_task( + litellm.cache._async_add_cache(result.json(), *args, **kwargs) + ) else: - asyncio.create_task(litellm.cache._async_add_cache(result, *args, **kwargs)) + asyncio.create_task( + litellm.cache._async_add_cache(result, *args, **kwargs) + ) # LOG SUCCESS - handle streaming success logging in the _next_ object - print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") - asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time)) - threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + asyncio.create_task( + logging_obj.async_success_handler(result, start_time, end_time) + ) + threading.Thread( + target=logging_obj.success_handler, args=(result, start_time, end_time) + ).start() # RETURN RESULT if isinstance(result, ModelResponse): - result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai return result - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() crash_reporting(*args, **kwargs, exception=traceback_exception) end_time = datetime.datetime.now() if logging_obj: try: - logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! - except Exception as e: - raise e - try: - await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time) + logging_obj.failure_handler( + e, traceback_exception, start_time, end_time + ) # DO NOT MAKE THREADED - router retry fallback relies on this! 
except Exception as e: raise e - + try: + await logging_obj.async_failure_handler( + e, traceback_exception, start_time, end_time + ) + except Exception as e: + raise e + call_type = original_function.__name__ if call_type == CallTypes.acompletion.value: num_retries = ( - kwargs.get("num_retries", None) - or litellm.num_retries - or None + kwargs.get("num_retries", None) or litellm.num_retries or None ) - litellm.num_retries = None # set retries to None to prevent infinite loops - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) - - if num_retries: - try: + litellm.num_retries = ( + None # set retries to None to prevent infinite loops + ) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", {} + ) + + if num_retries: + try: kwargs["num_retries"] = num_retries kwargs["original_function"] = original_function - if (isinstance(e, openai.RateLimitError)): # rate limiting specific error + if isinstance( + e, openai.RateLimitError + ): # rate limiting specific error kwargs["retry_strategy"] = "exponential_backoff_retry" - elif (isinstance(e, openai.APIError)): # generic api error + elif isinstance(e, openai.APIError): # generic api error kwargs["retry_strategy"] = "constant_retry" return await litellm.acompletion_with_retries(*args, **kwargs) except: pass - elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: + elif ( + isinstance(e, litellm.exceptions.ContextWindowExceededError) + and context_window_fallback_dict + and model in context_window_fallback_dict + ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return await original_function(*args, **kwargs) @@ -1872,6 +2296,7 @@ def client(original_function): else: return wrapper + ####### USAGE CALCULATOR ################ @@ -1879,7 +2304,10 @@ def client(original_function): # only used for together_computer LLMs def get_model_params_and_category(model_name): import re - params_match = re.search(r'(\d+b)', model_name) # catch all decimals like 3b, 70b, etc + + params_match = re.search( + r"(\d+b)", model_name + ) # catch all decimals like 3b, 70b, etc category = None if params_match != None: params_match = params_match.group(1) @@ -1900,30 +2328,36 @@ def get_model_params_and_category(model_name): return None + def get_replicate_completion_pricing(completion_response=None, total_time=0.0): # see https://replicate.com/pricing a100_40gb_price_per_second_public = 0.001150 # for all litellm currently supported LLMs, almost all requests go to a100_80gb - a100_80gb_price_per_second_public = 0.001400 # assume all calls sent to A100 80GB for now + a100_80gb_price_per_second_public = ( + 0.001400 # assume all calls sent to A100 80GB for now + ) if total_time == 0.0: - start_time = completion_response['created'] + start_time = completion_response["created"] end_time = completion_response["ended"] total_time = end_time - start_time - return a100_80gb_price_per_second_public*total_time + return a100_80gb_price_per_second_public * total_time -def _select_tokenizer(model: str): - # cohere +def _select_tokenizer(model: str): + # cohere import pkg_resources + if model in litellm.cohere_models: tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # anthropic + # anthropic elif model in litellm.anthropic_models: # Read the 
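The Replicate pricing above reduces to a flat price per second of prediction time; a worked example with a hypothetical 12.5-second run:

# assume all calls go to an A100 80GB, as get_replicate_completion_pricing does
a100_80gb_price_per_second = 0.001400
total_time = 12.5  # seconds the prediction ran (hypothetical)
print(f"${a100_80gb_price_per_second * total_time:.6f}")  # $0.017500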
JSON file - filename = pkg_resources.resource_filename(__name__, 'llms/tokenizers/anthropic_tokenizer.json') - with open(filename, 'r') as f: + filename = pkg_resources.resource_filename( + __name__, "llms/tokenizers/anthropic_tokenizer.json" + ) + with open(filename, "r") as f: json_data = json.load(f) # Decode the JSON data from utf-8 json_data_decoded = json.dumps(json_data, ensure_ascii=False) @@ -1932,15 +2366,16 @@ def _select_tokenizer(model: str): # load tokenizer tokenizer = Tokenizer.from_str(json_str) return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # llama2 - elif "llama-2" in model.lower(): + # llama2 + elif "llama-2" in model.lower(): tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} # default - tiktoken - else: + else: return {"type": "openai_tokenizer", "tokenizer": encoding} -def encode(model: str, text: str): + +def encode(model: str, text: str): """ Encodes the given text using the specified model. @@ -1955,12 +2390,18 @@ def encode(model: str, text: str): enc = tokenizer_json["tokenizer"].encode(text) return enc -def decode(model: str, tokens: List[int]): + +def decode(model: str, tokens: List[int]): tokenizer_json = _select_tokenizer(model=model) dec = tokenizer_json["tokenizer"].decode(tokens) return dec -def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-0613", text: Optional[str]= None): + +def openai_token_counter( + messages: Optional[list] = None, + model="gpt-3.5-turbo-0613", + text: Optional[str] = None, +): """ Return the number of tokens used by a list of messages. @@ -1972,7 +2413,9 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 print_verbose("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_message = ( + 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + ) tokens_per_name = -1 # if there's a name, the role is omitted elif model in litellm.open_ai_chat_completion_models: tokens_per_message = 3 @@ -1983,9 +2426,9 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 ) num_tokens = 0 - if text: + if text: num_tokens = len(encoding.encode(text, disallowed_special=())) - elif messages: + elif messages: for message in messages: num_tokens += tokens_per_message for key, value in message.items(): @@ -1995,7 +2438,8 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens -def token_counter(model="", text=None, messages: Optional[List] = None): + +def token_counter(model="", text=None, messages: Optional[List] = None): """ Count the number of tokens in a given text using a specified model. 
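A sketch of the tokenizer selection above from user code. token_counter is exported at the package level; encode/decode are assumed to be importable from litellm.utils, where they are defined:

from litellm import token_counter
from litellm.utils import encode, decode

messages = [{"role": "user", "content": "Hello, how are you?"}]
print(token_counter(model="gpt-3.5-turbo", messages=messages))  # OpenAI models route through tiktoken

tokens = encode(model="gpt-3.5-turbo", text="Hello, how are you?")
print(decode(model="gpt-3.5-turbo", tokens=tokens))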
@@ -2011,24 +2455,24 @@ def token_counter(model="", text=None, messages: Optional[List] = None): if text == None: if messages is not None: print_verbose(f"token_counter messages received: {messages}") - text = "" - for message in messages: + text = "" + for message in messages: if message.get("content", None): text += message["content"] - if 'tool_calls' in message: - for tool_call in message['tool_calls']: - if 'function' in tool_call: - function_arguments = tool_call['function']['arguments'] + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + if "function" in tool_call: + function_arguments = tool_call["function"]["arguments"] text += function_arguments else: raise ValueError("text and messages cannot both be None") num_tokens = 0 if model is not None: tokenizer_json = _select_tokenizer(model=model) - if tokenizer_json["type"] == "huggingface_tokenizer": + if tokenizer_json["type"] == "huggingface_tokenizer": enc = tokenizer_json["tokenizer"].encode(text) num_tokens = len(enc.ids) - elif tokenizer_json["type"] == "openai_tokenizer": + elif tokenizer_json["type"] == "openai_tokenizer": if model in litellm.open_ai_chat_completion_models: num_tokens = openai_token_counter(text=text, model=model) else: @@ -2047,7 +2491,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): model (str): The name of the model to use. Default is "" prompt_tokens (int): The number of tokens in the prompt. completion_tokens (int): The number of tokens in the completion. - + Returns: tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively. """ @@ -2059,7 +2503,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): azure_llms = { "gpt-35-turbo": "azure/gpt-3.5-turbo", "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct" + "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct", } if model in model_cost_ref: prompt_tokens_cost_usd_dollar = ( @@ -2075,7 +2519,8 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] * completion_tokens + model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] + * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in azure_llms: @@ -2103,22 +2548,22 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): def completion_cost( - completion_response=None, - model=None, - prompt="", - messages: List = [], - completion="", - total_time=0.0, # used for replicate - ): + completion_response=None, + model=None, + prompt="", + messages: List = [], + completion="", + total_time=0.0, # used for replicate +): """ Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm. Parameters: completion_response (litellm.ModelResponses): [Required] The response received from a LiteLLM completion request. - + [OPTIONAL PARAMS] model (str): Optional. The name of the language model used in the completion calls - prompt (str): Optional. The input prompt passed to the llm + prompt (str): Optional. The input prompt passed to the llm completion (str): Optional. The output completion text from the llm total_time (float): Optional. 
(Only used for Replicate LLMs) The total time used for the request in seconds @@ -2141,41 +2586,46 @@ def completion_cost( completion_tokens = 0 if completion_response is not None: # get input/output tokens from completion_response - prompt_tokens = completion_response['usage']['prompt_tokens'] - completion_tokens = completion_response['usage']['completion_tokens'] - model = model or completion_response['model'] # check if user passed an override for model, if it's none check completion_response['model'] + prompt_tokens = completion_response["usage"]["prompt_tokens"] + completion_tokens = completion_response["usage"]["completion_tokens"] + model = ( + model or completion_response["model"] + ) # check if user passed an override for model, if it's none check completion_response['model'] else: if len(messages) > 0: prompt_tokens = token_counter(model=model, messages=messages) - elif len(prompt) > 0: + elif len(prompt) > 0: prompt_tokens = token_counter(model=model, text=prompt) completion_tokens = token_counter(model=model, text=completion) - + # Calculate cost based on prompt_tokens, completion_tokens if "togethercomputer" in model: # together ai prices based on size of llm - # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json + # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) # replicate llms are calculate based on time for request running # see https://replicate.com/pricing - elif ( - model in litellm.replicate_models or - "replicate" in model - ): + elif model in litellm.replicate_models or "replicate" in model: return get_replicate_completion_pricing(completion_response, total_time) - prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token( - model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ( + prompt_tokens_cost_usd_dollar, + completion_tokens_cost_usd_dollar, + ) = cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar except: - return 0.0 # this should not block a users execution path + return 0.0 # this should not block a users execution path + ####### HELPER FUNCTIONS ################ -def register_model(model_cost: Union[str, dict]): +def register_model(model_cost: Union[str, dict]): """ - Register new / Override existing models (and their pricing) to specific providers. + Register new / Override existing models (and their pricing) to specific providers. 
Provide EITHER a model cost dictionary or a url to a hosted json blob - Example usage: + Example usage: model_cost_dict = { "gpt-4": { "max_tokens": 8192, @@ -2187,59 +2637,60 @@ def register_model(model_cost: Union[str, dict]): } """ loaded_model_cost = {} - if isinstance(model_cost, dict): + if isinstance(model_cost, dict): loaded_model_cost = model_cost - elif isinstance(model_cost, str): + elif isinstance(model_cost, str): loaded_model_cost = litellm.get_model_cost_map(url=model_cost) for key, value in loaded_model_cost.items(): ## override / add new keys to the existing model cost dictionary if key in litellm.model_cost: - for k,v in loaded_model_cost[key].items(): + for k, v in loaded_model_cost[key].items(): litellm.model_cost[key][k] = v # add new model names to provider lists - if value.get('litellm_provider') == 'openai': + if value.get("litellm_provider") == "openai": if key not in litellm.open_ai_chat_completion_models: litellm.open_ai_chat_completion_models.append(key) - elif value.get('litellm_provider') == 'text-completion-openai': + elif value.get("litellm_provider") == "text-completion-openai": if key not in litellm.open_ai_text_completion_models: litellm.open_ai_text_completion_models.append(key) - elif value.get('litellm_provider') == 'cohere': + elif value.get("litellm_provider") == "cohere": if key not in litellm.cohere_models: litellm.cohere_models.append(key) - elif value.get('litellm_provider') == 'anthropic': + elif value.get("litellm_provider") == "anthropic": if key not in litellm.anthropic_models: litellm.anthropic_models.append(key) - elif value.get('litellm_provider') == 'openrouter': - split_string = key.split('/', 1) + elif value.get("litellm_provider") == "openrouter": + split_string = key.split("/", 1) if key not in litellm.openrouter_models: litellm.openrouter_models.append(split_string[1]) - elif value.get('litellm_provider') == 'vertex_ai-text-models': + elif value.get("litellm_provider") == "vertex_ai-text-models": if key not in litellm.vertex_text_models: litellm.vertex_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-text-models': + elif value.get("litellm_provider") == "vertex_ai-code-text-models": if key not in litellm.vertex_code_text_models: litellm.vertex_code_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-chat-models': + elif value.get("litellm_provider") == "vertex_ai-chat-models": if key not in litellm.vertex_chat_models: litellm.vertex_chat_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-chat-models': + elif value.get("litellm_provider") == "vertex_ai-code-chat-models": if key not in litellm.vertex_code_chat_models: litellm.vertex_code_chat_models.append(key) - elif value.get('litellm_provider') == 'ai21': + elif value.get("litellm_provider") == "ai21": if key not in litellm.ai21_models: litellm.ai21_models.append(key) - elif value.get('litellm_provider') == 'nlp_cloud': + elif value.get("litellm_provider") == "nlp_cloud": if key not in litellm.nlp_cloud_models: litellm.nlp_cloud_models.append(key) - elif value.get('litellm_provider') == 'aleph_alpha': + elif value.get("litellm_provider") == "aleph_alpha": if key not in litellm.aleph_alpha_models: litellm.aleph_alpha_models.append(key) - elif value.get('litellm_provider') == 'bedrock': + elif value.get("litellm_provider") == "bedrock": if key not in litellm.bedrock_models: litellm.bedrock_models.append(key) return model_cost + def get_litellm_params( api_key=None, force_timeout=600, @@ -2258,7 +2709,7 @@ 
def get_litellm_params( model_info=None, proxy_server_request=None, acompletion=None, - preset_cache_key = None + preset_cache_key=None, ): litellm_params = { "acompletion": acompletion, @@ -2275,20 +2726,21 @@ def get_litellm_params( "model_info": model_info, "proxy_server_request": proxy_server_request, "preset_cache_key": preset_cache_key, - "stream_response": {} # litellm_call_id: ModelResponse Dict + "stream_response": {}, # litellm_call_id: ModelResponse Dict } return litellm_params + def get_optional_params_image_gen( - n: Optional[int]=None, - quality: Optional[str]=None, - response_format: Optional[str]=None, - size: Optional[str]=None, - style: Optional[str]=None, - user: Optional[str]=None, - custom_llm_provider: Optional[str]=None, - **kwargs + n: Optional[int] = None, + quality: Optional[str] = None, + response_format: Optional[str] = None, + size: Optional[str] = None, + style: Optional[str] = None, + user: Optional[str] = None, + custom_llm_provider: Optional[str] = None, + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2296,38 +2748,44 @@ def get_optional_params_image_gen( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - + default_params = { - "n": None, - "quality" : None, - "response_format" : None, - "size": None, + "n": None, + "quality": None, + "response_format": None, + "size": None, "style": None, "user": None, } - non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} + non_default_params = { + k: v + for k, v in passed_params.items() + if (k in default_params and v != default_params[k]) + } ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"Setting user/encoding format is not supported by {custom_llm_provider}. 
To drop it from the call, set `litellm.drop_params = True`.", + ) + final_params = {**non_default_params, **kwargs} return final_params - - + def get_optional_params_embeddings( # 2 optional params - user=None, + user=None, encoding_format=None, custom_llm_provider="", - **kwargs + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2335,26 +2793,31 @@ def get_optional_params_embeddings( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - - default_params = { - "user": None, - "encoding_format": None - } - non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} + default_params = {"user": None, "encoding_format": None} + + non_default_params = { + k: v + for k, v in passed_params.items() + if (k in default_params and v != default_params[k]) + } ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", + ) + final_params = {**non_default_params, **kwargs} return final_params + def get_optional_params( # use the openai defaults # 12 optional params functions=None, @@ -2376,7 +2839,7 @@ def get_optional_params( # use the openai defaults tools=None, tool_choice=None, max_retries=None, - **kwargs + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2386,18 +2849,18 @@ def get_optional_params( # use the openai defaults default_params = { "functions": None, "function_call": None, - "temperature":None, - "top_p":None, - "n":None, - "stream":None, - "stop":None, - "max_tokens":None, - "presence_penalty":None, - "frequency_penalty":None, + "temperature": None, + "top_p": None, + "n": None, + "stream": None, + "stop": None, + "max_tokens": None, + "presence_penalty": None, + "frequency_penalty": None, "logit_bias": None, - "user":None, - "model":None, - "custom_llm_provider":"", + "user": None, + "model": None, + "custom_llm_provider": "", "response_format": None, "seed": None, "tools": None, @@ -2405,64 +2868,108 @@ def get_optional_params( # use the openai defaults "max_retries": None, } # filter out those parameters that were passed with non-default values - non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])} + non_default_params = { + k: v + for k, v in passed_params.items() + if ( + k != "model" + and k != "custom_llm_provider" + and k in default_params + and v != default_params[k] + ) + } optional_params = {} ## raise exception if function calling passed in for a provider that doesn't support it - if "functions" in non_default_params or "function_call" in 
non_default_params or "tools" in non_default_params: - if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure": - if custom_llm_provider == "ollama": - # ollama actually supports json output + if ( + "functions" in non_default_params + or "function_call" in non_default_params + or "tools" in non_default_params + ): + if ( + custom_llm_provider != "openai" + and custom_llm_provider != "text-completion-openai" + and custom_llm_provider != "azure" + ): + if custom_llm_provider == "ollama": + # ollama actually supports json output optional_params["format"] = "json" - litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt + litellm.add_function_to_prompt = ( + True # so that main.py adds the function call to the prompt + ) if "tools" in non_default_params: - optional_params["functions_unsupported_model"] = non_default_params.pop("tools") - non_default_params.pop("tool_choice", None) # causes ollama requests to hang + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("tools") + non_default_params.pop( + "tool_choice", None + ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params["functions_unsupported_model"] = non_default_params.pop("functions") - elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("functions") + elif ( + custom_llm_provider == "anyscale" + and model == "mistralai/Mistral-7B-Instruct-v0.1" + ): # anyscale just supports function calling with mistral pass - elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead - optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) - else: - raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.") + elif ( + litellm.add_function_to_prompt + ): # if user opts to add it to prompt instead + optional_params["functions_unsupported_model"] = non_default_params.pop( + "tools", non_default_params.pop("functions") + ) + else: + raise UnsupportedParamsError( + status_code=500, + message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.", + ) - def _check_valid_arg(supported_params): - print_verbose(f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}") + def _check_valid_arg(supported_params): + print_verbose( + f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}" + ) print_verbose(f"\nLiteLLM: Params passed to completion() {passed_params}") - print_verbose(f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}") + print_verbose( + f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}" + ) unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: - if k == "n" and n == 1: # langchain sends n=1 as a default value - continue # skip this param - if k == "max_retries": # TODO: This is a patch. We support max retries for OpenAI, Azure. 
For non OpenAI LLMs we need to add support for max retries - continue # skip this param + if k == "n" and n == 1: # langchain sends n=1 as a default value + continue # skip this param + if ( + k == "max_retries" + ): # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries + continue # skip this param # Always keeps this in elif code blocks - else: + else: unsupported_params[k] = non_default_params[k] if unsupported_params and not litellm.drop_params: - raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.", + ) + def _map_and_modify_arg(supported_params: dict, provider: str, model: str): """ filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`. """ filtered_stop = None - if "stop" in supported_params and litellm.drop_params: - if provider == "bedrock" and "amazon" in model: + if "stop" in supported_params and litellm.drop_params: + if provider == "bedrock" and "amazon" in model: filtered_stop = [] - if isinstance(stop, list): - for s in stop: - if re.match(r'^(\|+|User:)$', s): - filtered_stop.append(s) - if filtered_stop is not None: + if isinstance(stop, list): + for s in stop: + if re.match(r"^(\|+|User:)$", s): + filtered_stop.append(s) + if filtered_stop is not None: supported_params["stop"] = filtered_stop return supported_params - ## raise exception if provider doesn't support passed in param + ## raise exception if provider doesn't support passed in param if custom_llm_provider == "anthropic": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle anthropic params @@ -2470,7 +2977,7 @@ def get_optional_params( # use the openai defaults optional_params["stream"] = stream if stop is not None: if type(stop) == str: - stop = [stop] # openai can accept str/list for stop + stop = [stop] # openai can accept str/list for stop optional_params["stop_sequences"] = stop if temperature is not None: optional_params["temperature"] = temperature @@ -2479,8 +2986,18 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens_to_sample"] = max_tokens elif custom_llm_provider == "cohere": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop", "n"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "logit_bias", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + ] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2502,8 +3019,15 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "maritalk": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + 
"max_tokens", + "top_p", + "presence_penalty", + "stop", + ] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2521,17 +3045,24 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stopping_tokens"] = stop elif custom_llm_provider == "replicate": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "top_p", + "stop", + "seed", + ] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream"] = stream return optional_params if max_tokens is not None: if "vicuna" in model or "flan" in model: optional_params["max_length"] = max_tokens - elif "meta/codellama-13b" in model: + elif "meta/codellama-13b" in model: optional_params["max_tokens"] = max_tokens else: optional_params["max_new_tokens"] = max_tokens @@ -2542,7 +3073,7 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "huggingface": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None @@ -2556,7 +3087,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -2567,7 +3100,7 @@ def get_optional_params( # use the openai defaults if max_tokens == 0: max_tokens = 1 optional_params["max_new_tokens"] = max_tokens - if n is not None: + if n is not None: optional_params["best_of"] = n if presence_penalty is not None: optional_params["repetition_penalty"] = presence_penalty @@ -2575,12 +3108,21 @@ def get_optional_params( # use the openai defaults # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. 
Defaults to False optional_params["decoder_input_details"] = special_params["echo"] - passed_params.pop("echo", None) # since we handle translating echo, we should not send it to TGI request + passed_params.pop( + "echo", None + ) # since we handle translating echo, we should not send it to TGI request elif custom_llm_provider == "together_ai": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "top_p", + "stop", + "frequency_penalty", + ] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream_tokens"] = stream if temperature is not None: @@ -2590,12 +3132,23 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens if frequency_penalty is not None: - optional_params["repetition_penalty"] = frequency_penalty # https://docs.together.ai/reference/inference + optional_params[ + "repetition_penalty" + ] = frequency_penalty # https://docs.together.ai/reference/inference if stop is not None: - optional_params["stop"] = stop + optional_params["stop"] = stop elif custom_llm_provider == "ai21": - ## check if unsupported param passed in - supported_params = ["stream", "n", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty", "presence_penalty"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "n", + "temperature", + "max_tokens", + "top_p", + "stop", + "frequency_penalty", + "presence_penalty", + ] _check_valid_arg(supported_params=supported_params) if stream: @@ -2614,11 +3167,13 @@ def get_optional_params( # use the openai defaults optional_params["frequencyPenalty"] = {"scale": frequency_penalty} if presence_penalty is not None: optional_params["presencePenalty"] = {"scale": presence_penalty} - elif custom_llm_provider == "palm": # https://developers.generativeai.google/tutorials/curl_quickstart - ## check if unsupported param passed in + elif ( + custom_llm_provider == "palm" + ): # https://developers.generativeai.google/tutorials/curl_quickstart + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -2631,13 +3186,11 @@ def get_optional_params( # use the openai defaults optional_params["stop_sequences"] = stop if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens - elif ( - custom_llm_provider == "vertex_ai" - ): - ## check if unsupported param passed in + elif custom_llm_provider == "vertex_ai": + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "max_tokens", "stream"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -2647,7 +3200,7 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "sagemaker": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, 
presence_penalty default to None @@ -2661,7 +3214,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -2684,7 +3239,7 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream elif "anthropic" in model: supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] @@ -2699,9 +3254,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if stop is not None: optional_params["stop_sequences"] = stop - if stream: + if stream: optional_params["stream"] = stream - elif "amazon" in model: # amazon titan llms + elif "amazon" in model: # amazon titan llms supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -2710,13 +3265,15 @@ def get_optional_params( # use the openai defaults if temperature is not None: optional_params["temperature"] = temperature if stop is not None: - filtered_stop = _map_and_modify_arg({"stop": stop}, provider="bedrock", model=model) + filtered_stop = _map_and_modify_arg( + {"stop": stop}, provider="bedrock", model=model + ) optional_params["stopSequences"] = filtered_stop["stop"] if top_p is not None: optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "meta" in model: # amazon / meta llms + elif "meta" in model: # amazon / meta llms supported_params = ["max_tokens", "temperature", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -2726,9 +3283,9 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "cohere" in model: # cohere models on bedrock + elif "cohere" in model: # cohere models on bedrock supported_params = ["stream", "temperature", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle cohere params @@ -2739,7 +3296,16 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "aleph_alpha": - supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"] + supported_params = [ + "max_tokens", + "stream", + "top_p", + "temperature", + "presence_penalty", + "frequency_penalty", + "n", + "stop", + ] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: optional_params["maximum_tokens"] = max_tokens @@ -2758,9 +3324,16 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "ollama": - supported_params = ["max_tokens", "stream", "top_p", "temperature", "frequency_penalty", "stop"] + supported_params = 
[ + "max_tokens", + "stream", + "top_p", + "temperature", + "frequency_penalty", + "stop", + ] _check_valid_arg(supported_params=supported_params) - + if max_tokens is not None: optional_params["num_predict"] = max_tokens if stream: @@ -2774,7 +3347,16 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "nlp_cloud": - supported_params = ["max_tokens", "stream", "temperature", "top_p", "presence_penalty", "frequency_penalty", "n", "stop"] + supported_params = [ + "max_tokens", + "stream", + "temperature", + "top_p", + "presence_penalty", + "frequency_penalty", + "n", + "stop", + ] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: @@ -2806,62 +3388,94 @@ def get_optional_params( # use the openai defaults if stream: optional_params["stream"] = stream elif custom_llm_provider == "deepinfra": - supported_params = ["temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"] + supported_params = [ + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + ] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature if top_p: optional_params["top_p"] = top_p - if n: + if n: optional_params["n"] = n - if stream: + if stream: optional_params["stream"] = stream - if stop: + if stop: optional_params["stop"] = stop - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty - if logit_bias: + if logit_bias: optional_params["logit_bias"] = logit_bias - if user: + if user: optional_params["user"] = user elif custom_llm_provider == "perplexity": - supported_params = ["temperature", "top_p", "stream", "max_tokens", "presence_penalty", "frequency_penalty"] + supported_params = [ + "temperature", + "top_p", + "stream", + "max_tokens", + "presence_penalty", + "frequency_penalty", + ] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if temperature == 0 and model == "mistral-7b-instruct": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistral-7b-instruct" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty elif custom_llm_provider == "anyscale": - supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"] + 
supported_params = [ + "temperature", + "top_p", + "stream", + "max_tokens", + "stop", + "frequency_penalty", + "presence_penalty", + ] if model == "mistralai/Mistral-7B-Instruct-v0.1": supported_params += ["functions", "function_call", "tools", "tool_choice"] _check_valid_arg(supported_params=supported_params) optional_params = non_default_params if temperature is not None: - if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "mistral": supported_params = ["temperature", "top_p", "stream", "max_tokens"] @@ -2869,13 +3483,13 @@ def get_optional_params( # use the openai defaults optional_params = non_default_params if temperature is not None: optional_params["temperature"] = temperature - if top_p is not None: + if top_p is not None: optional_params["top_p"] = top_p - if stream is not None: + if stream is not None: optional_params["stream"] = stream - if max_tokens is not None: + if max_tokens is not None: optional_params["max_tokens"] = max_tokens - + # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion safe_mode = passed_params.pop("safe_mode", None) random_seed = passed_params.pop("random_seed", None) @@ -2884,9 +3498,29 @@ def get_optional_params( # use the openai defaults extra_body["safe_mode"] = safe_mode if random_seed is not None: extra_body["random_seed"] = random_seed - optional_params["extra_body"] = extra_body # openai client supports `extra_body` param + optional_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param elif custom_llm_provider == "openrouter": - supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"] + supported_params = [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -2923,7 +3557,7 @@ def get_optional_params( # use the openai defaults optional_params["tool_choice"] = tool_choice if max_retries is not None: optional_params["max_retries"] = max_retries - + # OpenRouter-only parameters extra_body = {} transforms = passed_params.pop("transforms", None) @@ -2935,9 +3569,29 @@ def get_optional_params( # use the openai defaults extra_body["models"] = models if route is not None: extra_body["route"] = route - optional_params["extra_body"] = extra_body # openai client supports `extra_body` param + optional_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param else: # assume passing in params for openai/azure openai - supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", 
"user", "response_format", "seed", "tools", "tool_choice", "max_retries"] + supported_params = [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] _check_valid_arg(supported_params=supported_params) if functions is not None: optional_params["functions"] = functions @@ -2974,35 +3628,44 @@ def get_optional_params( # use the openai defaults if max_retries is not None: optional_params["max_retries"] = max_retries optional_params = non_default_params - # if user passed in non-default kwargs for specific providers/models, pass them along - for k in passed_params.keys(): - if k not in default_params.keys(): + # if user passed in non-default kwargs for specific providers/models, pass them along + for k in passed_params.keys(): + if k not in default_params.keys(): optional_params[k] = passed_params[k] return optional_params -def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None): + +def get_llm_provider( + model: str, + custom_llm_provider: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, +): try: dynamic_api_key = None # check if llm provider provided - + if custom_llm_provider: return model, custom_llm_provider, dynamic_api_key, api_base - - if api_key and api_key.startswith("os.environ/"): + + if api_key and api_key.startswith("os.environ/"): dynamic_api_key = get_secret(api_key) # check if llm provider part of model name - if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list: + if ( + model.split("/", 1)[0] in litellm.provider_list + and model.split("/", 1)[0] not in litellm.model_list + ): custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] if custom_llm_provider == "perplexity": # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai api_base = "https://api.perplexity.ai" dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") - elif custom_llm_provider == "anyscale": + elif custom_llm_provider == "anyscale": # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.endpoints.anyscale.com/v1" dynamic_api_key = get_secret("ANYSCALE_API_KEY") - elif custom_llm_provider == "deepinfra": + elif custom_llm_provider == "deepinfra": # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.deepinfra.com/v1/openai" dynamic_api_key = get_secret("DEEPINFRA_API_KEY") @@ -3013,7 +3676,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint - if api_base: + if api_base: for endpoint in litellm.openai_compatible_endpoints: if endpoint in api_base: if endpoint == "api.perplexity.ai": @@ -3032,20 +3695,26 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) 
## openai - chatcompletion + text completion - if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models: + if ( + model in litellm.open_ai_chat_completion_models + or "ft:gpt-3.5-turbo" in model + or model in litellm.openai_image_generation_models + ): custom_llm_provider = "openai" elif model in litellm.open_ai_text_completion_models: custom_llm_provider = "text-completion-openai" - ## anthropic + ## anthropic elif model in litellm.anthropic_models: custom_llm_provider = "anthropic" ## cohere elif model in litellm.cohere_models or model in litellm.cohere_embedding_models: custom_llm_provider = "cohere" ## replicate - elif model in litellm.replicate_models or (":" in model and len(model)>64): + elif model in litellm.replicate_models or (":" in model and len(model) > 64): model_parts = model.split(":") - if len(model_parts) > 1 and len(model_parts[1])==64: ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + if ( + len(model_parts) > 1 and len(model_parts[1]) == 64 + ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" custom_llm_provider = "replicate" elif model in litellm.replicate_models: custom_llm_provider = "replicate" @@ -3055,22 +3724,22 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ ## openrouter elif model in litellm.maritalk_models: custom_llm_provider = "maritalk" - ## vertex - text + chat + language (gemini) models - elif( - model in litellm.vertex_chat_models or - model in litellm.vertex_code_chat_models or - model in litellm.vertex_text_models or - model in litellm.vertex_code_text_models or - model in litellm.vertex_language_models + ## vertex - text + chat + language (gemini) models + elif ( + model in litellm.vertex_chat_models + or model in litellm.vertex_code_chat_models + or model in litellm.vertex_text_models + or model in litellm.vertex_code_text_models + or model in litellm.vertex_language_models ): custom_llm_provider = "vertex_ai" - ## ai21 + ## ai21 elif model in litellm.ai21_models: custom_llm_provider = "ai21" - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: custom_llm_provider = "aleph_alpha" - ## baseten + ## baseten elif model in litellm.baseten_models: custom_llm_provider = "baseten" ## nlp_cloud @@ -3080,107 +3749,81 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ elif model in litellm.petals_models: custom_llm_provider = "petals" ## bedrock - elif model in litellm.bedrock_models or model in litellm.bedrock_embedding_models: + elif ( + model in litellm.bedrock_models or model in litellm.bedrock_embedding_models + ): custom_llm_provider = "bedrock" # openai embeddings elif model in litellm.open_ai_embedding_models: custom_llm_provider = "openai" - if custom_llm_provider is None or custom_llm_provider=="": - print() # noqa - print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m") # noqa - print() # noqa + if custom_llm_provider is None or custom_llm_provider == "": + if litellm.suppress_debug_info == False: + print() # noqa + print( # noqa + "\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa + ) # noqa + print() # noqa error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. 
For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers" # maps to openai.NotFoundError, this is raised when openai does not recognize the llm - raise litellm.exceptions.NotFoundError( # type: ignore + raise litellm.exceptions.NotFoundError( # type: ignore message=error_str, model=model, response=httpx.Response( status_code=404, - content=error_str, - request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm") # type: ignore + content=error_str, + request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore ), - llm_provider="" + llm_provider="", ) return model, custom_llm_provider, dynamic_api_key, api_base - except Exception as e: + except Exception as e: raise e def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): - api_key = (dynamic_api_key or litellm.api_key) - # openai + api_key = dynamic_api_key or litellm.api_key + # openai if llm_provider == "openai" or llm_provider == "text-completion-openai": - api_key = ( - api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") - ) - # anthropic + api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + # anthropic elif llm_provider == "anthropic": - api_key = ( - api_key or - litellm.anthropic_key or - get_secret("ANTHROPIC_API_KEY") - ) - # ai21 + api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY") + # ai21 elif llm_provider == "ai21": - api_key = ( - api_key or - litellm.ai21_key or - get_secret("AI211_API_KEY") - ) - # aleph_alpha + api_key = api_key or litellm.ai21_key or get_secret("AI211_API_KEY") + # aleph_alpha elif llm_provider == "aleph_alpha": api_key = ( - api_key or - litellm.aleph_alpha_key or - get_secret("ALEPH_ALPHA_API_KEY") + api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") ) - # baseten + # baseten elif llm_provider == "baseten": - api_key = ( - api_key or - litellm.baseten_key or - get_secret("BASETEN_API_KEY") - ) - # cohere + api_key = api_key or litellm.baseten_key or get_secret("BASETEN_API_KEY") + # cohere elif llm_provider == "cohere": - api_key = ( - api_key or - litellm.cohere_key or - get_secret("COHERE_API_KEY") - ) - # huggingface + api_key = api_key or litellm.cohere_key or get_secret("COHERE_API_KEY") + # huggingface elif llm_provider == "huggingface": api_key = ( - api_key or - litellm.huggingface_key or - get_secret("HUGGINGFACE_API_KEY") + api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") ) - # nlp_cloud + # nlp_cloud elif llm_provider == "nlp_cloud": - api_key = ( - api_key or - litellm.nlp_cloud_key or - get_secret("NLP_CLOUD_API_KEY") - ) - # replicate + api_key = api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") + # replicate elif llm_provider == "replicate": - api_key = ( - api_key or - litellm.replicate_key or - get_secret("REPLICATE_API_KEY") - ) - # together_ai + api_key = api_key or litellm.replicate_key or get_secret("REPLICATE_API_KEY") + # together_ai elif llm_provider == "together_ai": api_key = ( - api_key or - litellm.togetherai_api_key or - get_secret("TOGETHERAI_API_KEY") or - get_secret("TOGETHER_AI_TOKEN") + api_key + or litellm.togetherai_api_key + or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") ) return api_key + def get_max_tokens(model: str): """ Get the maximum number of tokens allowed for a given model. 
@@ -3198,6 +3841,7 @@ def get_max_tokens(model: str): >>> get_max_tokens("gpt-4") 8192 """ + def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3223,19 +3867,21 @@ def get_max_tokens(model: str): try: if model in litellm.model_cost: return litellm.model_cost[model]["max_tokens"] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return max_tokens - else: + else: raise Exception() except: - raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") + raise Exception( + "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" + ) def get_model_info(model: str): """ - Get a dict for the maximum tokens (context window), + Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. Parameters: @@ -3262,6 +3908,7 @@ def get_model_info(model: str): "mode": "chat" } """ + def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3283,30 +3930,34 @@ def get_model_info(model: str): return None except requests.exceptions.RequestException as e: return None + try: azure_llms = { "gpt-35-turbo": "azure/gpt-3.5-turbo", "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct" + "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct", } - if model in azure_llms: + if model in azure_llms: model = azure_llms[model] if model in litellm.model_cost: return litellm.model_cost[model] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return { "max_tokens": max_tokens, "input_cost_per_token": 0, "output_cost_per_token": 0, "litellm_provider": "huggingface", - "mode": "chat" + "mode": "chat", } - else: + else: raise Exception() except: - raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") + raise Exception( + "This model isn't mapped yet. 
Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" + ) + def json_schema_type(python_type_name: str): """Converts standard python types to json schema types @@ -3333,6 +3984,7 @@ def json_schema_type(python_type_name: str): return python_to_json_schema_types.get(python_type_name, "string") + def function_to_dict(input_function): # noqa: C901 """Using type hints and numpy-styled docstring, produce a dictionnary usable for OpenAI function calling @@ -3422,6 +4074,7 @@ def function_to_dict(input_function): # noqa: C901 return result + def load_test_model( model: str, custom_llm_provider: str = "", @@ -3464,13 +4117,14 @@ def load_test_model( "exception": e, } -def validate_environment(model: Optional[str]=None) -> dict: + +def validate_environment(model: Optional[str] = None) -> dict: """ Checks if the environment variables are valid for the given model. - + Args: model (Optional[str]): The name of the model. Defaults to None. - + Returns: dict: A dictionary containing the following keys: - keys_in_environment (bool): True if all the required keys are present in the environment, False otherwise. @@ -3480,7 +4134,10 @@ def validate_environment(model: Optional[str]=None) -> dict: missing_keys: List[str] = [] if model is None: - return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + return { + "keys_in_environment": keys_in_environment, + "missing_keys": missing_keys, + } ## EXTRACT LLM PROVIDER - if model name provided try: custom_llm_provider = get_llm_provider(model=model) @@ -3491,7 +4148,7 @@ def validate_environment(model: Optional[str]=None) -> dict: # custom_llm_provider = model.split("/", 1)[0] # model = model.split("/", 1)[1] # custom_llm_provider_passed_in = True - + if custom_llm_provider: if custom_llm_provider == "openai": if "OPENAI_API_KEY" in os.environ: @@ -3499,12 +4156,16 @@ def validate_environment(model: Optional[str]=None) -> dict: else: missing_keys.append("OPENAI_API_KEY") elif custom_llm_provider == "azure": - if ("AZURE_API_BASE" in os.environ + if ( + "AZURE_API_BASE" in os.environ and "AZURE_API_VERSION" in os.environ - and "AZURE_API_KEY" in os.environ): + and "AZURE_API_KEY" in os.environ + ): keys_in_environment = True else: - missing_keys.extend(["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"]) + missing_keys.extend( + ["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"] + ) elif custom_llm_provider == "anthropic": if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -3526,8 +4187,7 @@ def validate_environment(model: Optional[str]=None) -> dict: else: missing_keys.append("OPENROUTER_API_KEY") elif custom_llm_provider == "vertex_ai": - if ("VERTEXAI_PROJECT" in os.environ - and "VERTEXAI_LOCATION" in os.environ): + if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) @@ -3561,20 +4221,26 @@ def validate_environment(model: Optional[str]=None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - elif custom_llm_provider == "bedrock": - if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: + elif custom_llm_provider == "bedrock": + if ( + "AWS_ACCESS_KEY_ID" in os.environ + and "AWS_SECRET_ACCESS_KEY" in os.environ + ): keys_in_environment = True else: missing_keys.append("AWS_ACCESS_KEY_ID") missing_keys.append("AWS_SECRET_ACCESS_KEY") else: ## openai - chatcompletion + text completion - if 
model in litellm.open_ai_chat_completion_models or litellm.open_ai_text_completion_models: + if ( + model in litellm.open_ai_chat_completion_models + or litellm.open_ai_text_completion_models + ): if "OPENAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("OPENAI_API_KEY") - ## anthropic + ## anthropic elif model in litellm.anthropic_models: if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -3600,36 +4266,35 @@ def validate_environment(model: Optional[str]=None) -> dict: missing_keys.append("OPENROUTER_API_KEY") ## vertex - text + chat models elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models: - if ("VERTEXAI_PROJECT" in os.environ - and "VERTEXAI_LOCATION" in os.environ): + if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) - ## huggingface + ## huggingface elif model in litellm.huggingface_models: if "HUGGINGFACE_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("HUGGINGFACE_API_KEY") - ## ai21 + ## ai21 elif model in litellm.ai21_models: if "AI21_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("AI21_API_KEY") - ## together_ai + ## together_ai elif model in litellm.together_ai_models: if "TOGETHERAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("TOGETHERAI_API_KEY") - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: if "ALEPH_ALPHA_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("ALEPH_ALPHA_API_KEY") - ## baseten + ## baseten elif model in litellm.baseten_models: if "BASETEN_API_KEY" in os.environ: keys_in_environment = True @@ -3641,7 +4306,8 @@ def validate_environment(model: Optional[str]=None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + def set_callbacks(callback_list, function_id=None): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger @@ -3735,6 +4401,7 @@ def set_callbacks(callback_list, function_id=None): except Exception as e: raise e + # NOTE: DEPRECATING this in favor of using failure_handler() in Logging: def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger @@ -3878,7 +4545,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k exception_logging(logger_fn=user_logger_fn, exception=e) pass -async def convert_to_streaming_response_async(response_object: Optional[dict]=None): + +async def convert_to_streaming_response_async(response_object: Optional[dict] = None): """ Asynchronously converts a response object to a streaming response. 
@@ -3909,7 +4577,7 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No content=choice["message"].get("content", None), role=choice["message"]["role"], function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) + tool_calls=choice["message"].get("tool_calls", None), ) finish_reason = choice.get("finish_reason", None) @@ -3925,10 +4593,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No model_response_object.usage = Usage( completion_tokens=response_object["usage"].get("completion_tokens", 0), prompt_tokens=response_object["usage"].get("prompt_tokens", 0), - total_tokens=response_object["usage"].get("total_tokens", 0) + total_tokens=response_object["usage"].get("total_tokens", 0), ) - if "id" in response_object: model_response_object.id = response_object["id"] @@ -3941,19 +4608,20 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No yield model_response_object await asyncio.sleep(0) -def convert_to_streaming_response(response_object: Optional[dict]=None): + +def convert_to_streaming_response(response_object: Optional[dict] = None): # used for yielding Cache hits when stream == True if response_object is None: raise Exception("Error in response object format") model_response_object = ModelResponse(stream=True) - choice_list=[] - for idx, choice in enumerate(response_object["choices"]): + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): delta = Delta( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None), ) finish_reason = choice.get("finish_reason", None) if finish_reason == None: @@ -3964,100 +4632,118 @@ def convert_to_streaming_response(response_object: Optional[dict]=None): model_response_object.choices = choice_list if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: + if "id" in response_object: model_response_object.id = response_object["id"] if "system_fingerprint" in response_object: model_response_object.system_fingerprint = response_object["system_fingerprint"] - if "model" in response_object: + if "model" in response_object: model_response_object.model = response_object["model"] yield model_response_object -def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = 
"completion", stream = False): - try: - if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)): - if response_object is None or model_response_object is None: - raise Exception("Error in response object format") - if stream == True: - # for returning cached responses, we need to yield a generator - return convert_to_streaming_response( - response_object=response_object - ) - choice_list=[] - for idx, choice in enumerate(response_object["choices"]): - message = Message( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) - ) - finish_reason = choice.get("finish_reason", None) - if finish_reason == None: - # gpt-4 vision can return 'finish_reason' or 'finish_details' - finish_reason = choice.get("finish_details") - choice = Choices(finish_reason=finish_reason, index=idx, message=message) - choice_list.append(choice) - model_response_object.choices = choice_list +def convert_to_model_response_object( + response_object: Optional[dict] = None, + model_response_object: Optional[ + Union[ModelResponse, EmbeddingResponse, ImageResponse] + ] = None, + response_type: Literal[ + "completion", "embedding", "image_generation" + ] = "completion", + stream=False, +): + try: + if response_type == "completion" and ( + model_response_object is None + or isinstance(model_response_object, ModelResponse) + ): + if response_object is None or model_response_object is None: + raise Exception("Error in response object format") + if stream == True: + # for returning cached responses, we need to yield a generator + return convert_to_streaming_response(response_object=response_object) + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): + message = Message( + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None), + ) + finish_reason = choice.get("finish_reason", None) + if finish_reason == None: + # gpt-4 vision can return 'finish_reason' or 'finish_details' + finish_reason = choice.get("finish_details") + choice = Choices( + finish_reason=finish_reason, index=idx, message=message + ) + choice_list.append(choice) + model_response_object.choices = choice_list - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: - model_response_object.id = response_object["id"] - - if "system_fingerprint" in response_object: - model_response_object.system_fingerprint = response_object["system_fingerprint"] + if "id" in response_object: + model_response_object.id = 
response_object["id"] - if "model" in response_object: - model_response_object.model = response_object["model"] - return model_response_object - elif response_type == "embedding" and (model_response_object is None or isinstance(model_response_object, EmbeddingResponse)): - if response_object is None: - raise Exception("Error in response object format") - - if model_response_object is None: - model_response_object = EmbeddingResponse() + if "system_fingerprint" in response_object: + model_response_object.system_fingerprint = response_object[ + "system_fingerprint" + ] - if "model" in response_object: - model_response_object.model = response_object["model"] - - if "object" in response_object: - model_response_object.object = response_object["object"] + if "model" in response_object: + model_response_object.model = response_object["model"] + return model_response_object + elif response_type == "embedding" and ( + model_response_object is None + or isinstance(model_response_object, EmbeddingResponse) + ): + if response_object is None: + raise Exception("Error in response object format") - + if model_response_object is None: + model_response_object = EmbeddingResponse() + + if "model" in response_object: + model_response_object.model = response_object["model"] + + if "object" in response_object: + model_response_object.object = response_object["object"] + + model_response_object.data = response_object["data"] + + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + + return model_response_object + elif response_type == "image_generation" and ( + model_response_object is None + or isinstance(model_response_object, ImageResponse) + ): + if response_object is None: + raise Exception("Error in response object format") + + if model_response_object is None: + model_response_object = ImageResponse() + + if "created" in response_object: + model_response_object.created = response_object["created"] + + if "data" in response_object: model_response_object.data = response_object["data"] - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - - - return model_response_object - elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)): - if response_object is None: - raise Exception("Error in response object format") - - if model_response_object is None: - model_response_object = ImageResponse() - - if "created" in response_object: - model_response_object.created = response_object["created"] - - if "data" in response_object: - model_response_object.data = response_object["data"] - - return model_response_object - except Exception as e: - raise Exception(f"Invalid response object {e}") + return model_response_object + except Exception as e: + raise Exception(f"Invalid response object {e}") # NOTE: DEPRECATING this in favor of using 
success_handler() in Logging: @@ -4163,6 +4849,7 @@ def valid_model(model): except: raise BadRequestError(message="", model=model, llm_provider="") + def check_valid_key(model: str, api_key: str): """ Checks if a given API key is valid for a specific model by making a litellm.completion call with max_tokens=10 @@ -4176,16 +4863,19 @@ def check_valid_key(model: str, api_key: str): """ messages = [{"role": "user", "content": "Hey, how's it going?"}] try: - litellm.completion(model=model, messages=messages, api_key=api_key, max_tokens=10) + litellm.completion( + model=model, messages=messages, api_key=api_key, max_tokens=10 + ) return True except AuthenticationError as e: return False except Exception as e: return False -def _should_retry(status_code: int): + +def _should_retry(status_code: int): """ - Reimplementation of openai's should retry logic, since that one can't be imported. + Reimplementation of openai's should retry logic, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639 """ # If the server explicitly says whether or not to retry, obey. @@ -4207,13 +4897,20 @@ def _should_retry(status_code: int): return False -def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None, min_timeout: int = 0): + +def _calculate_retry_after( + remaining_retries: int, + max_retries: int, + response_headers: Optional[httpx.Headers] = None, + min_timeout: int = 0, +): """ Reimplementation of openai's calculate retry after, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631 """ try: - import email # openai import + import email # openai import + # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After # # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for @@ -4234,7 +4931,7 @@ def _calculate_retry_after(remaining_retries: int, max_retries: int, response_he except Exception: retry_after = -1 - + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. if 0 < retry_after <= 60: return retry_after @@ -4251,6 +4948,7 @@ def _calculate_retry_after(remaining_retries: int, max_retries: int, response_he timeout = sleep_seconds * jitter return timeout if timeout >= min_timeout else min_timeout + # integration helper function def modify_integration(integration_name, integration_params): global supabaseClient @@ -4260,7 +4958,12 @@ def modify_integration(integration_name, integration_params): # custom prompt helper function -def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""): +def register_prompt_template( + model: str, + roles: dict, + initial_prompt_value: str = "", + final_prompt_value: str = "", +): """ Register a prompt template to follow your custom format for a given model @@ -4274,19 +4977,19 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str dict: The updated custom prompt dictionary. 
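The `_calculate_retry_after` hunk above honors a reasonable `Retry-After` header and otherwise falls back to capped exponential backoff with jitter. A standalone sketch of that logic (handling only the integer seconds form of `Retry-After`; the real code also parses the HTTP-date form via the `email` module):

```
import random
from typing import Mapping, Optional

def retry_after_sketch(
    remaining_retries: int,
    max_retries: int,
    headers: Optional[Mapping[str, str]] = None,
    min_timeout: float = 0.0,
) -> float:
    # Honor an explicit, reasonable Retry-After (seconds form only in this sketch).
    try:
        retry_after = int((headers or {}).get("retry-after", -1))
    except (TypeError, ValueError):
        retry_after = -1
    if 0 < retry_after <= 60:
        return retry_after

    # Otherwise: exponential backoff, capped, with jitter.
    initial_delay, max_delay = 0.5, 8.0
    nb_retries = max_retries - remaining_retries
    sleep_seconds = min(initial_delay * (2.0 ** nb_retries), max_delay)
    jitter = 1.0 - 0.25 * random.random()
    timeout = sleep_seconds * jitter
    return timeout if timeout >= min_timeout else min_timeout

print(retry_after_sketch(remaining_retries=1, max_retries=3, headers={"retry-after": "5"}))
```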
Example usage: ``` - import litellm + import litellm litellm.register_prompt_template( - model="llama-2", + model="llama-2", initial_prompt_value="You are a good assistant" # [OPTIONAL] - roles={ + roles={ "system": { "pre_message": "[INST] <>\n", # [OPTIONAL] "post_message": "\n<>\n [/INST]\n" # [OPTIONAL] }, - "user": { + "user": { "pre_message": "[INST] ", # [OPTIONAL] "post_message": " [/INST]" # [OPTIONAL] - }, + }, "assistant": { "pre_message": "\n" # [OPTIONAL] "post_message": "\n" # [OPTIONAL] @@ -4300,11 +5003,12 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str litellm.custom_prompt_dict[model] = { "roles": roles, "initial_prompt_value": initial_prompt_value, - "final_prompt_value": final_prompt_value + "final_prompt_value": final_prompt_value, } return litellm.custom_prompt_dict -####### DEPRECATED ################ + +####### DEPRECATED ################ def get_all_keys(llm_provider=None): @@ -4393,20 +5097,25 @@ def get_model_list(): f"[Non-Blocking Error] get_model_list error - {traceback.format_exc()}" ) + ####### EXCEPTION MAPPING ################ def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - ): + model, + original_exception, + custom_llm_provider, + completion_kwargs={}, +): global user_logger_fn, liteDebuggerClient exception_mapping_worked = False if litellm.suppress_debug_info is False: - print() # noqa - print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m") # noqa - print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa - print() # noqa + print() # noqa + print( # noqa + "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa + ) # noqa + print( # noqa + "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." 
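The docstring above shows how a template is registered into `litellm.custom_prompt_dict`. As a hedged, hypothetical illustration (not litellm's actual prompt factory code) of how such a roles dictionary might be applied to a messages list, with llama-2 style tags used purely as sample values:

```
# Hypothetical rendering of a registered template: each message is wrapped in
# its role's pre/post strings, with optional initial/final values around the whole prompt.
def render_prompt_sketch(template: dict, messages: list) -> str:
    prompt = template.get("initial_prompt_value", "")
    for m in messages:
        role_fmt = template["roles"].get(m["role"], {})
        prompt += role_fmt.get("pre_message", "") + m["content"] + role_fmt.get("post_message", "")
    return prompt + template.get("final_prompt_value", "")

llama2_template = {
    "initial_prompt_value": "",
    "final_prompt_value": "",
    "roles": {
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
        "assistant": {"pre_message": "\n", "post_message": "\n"},
    },
}
print(render_prompt_sketch(llama2_template, [{"role": "user", "content": "Hey"}]))
```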
# noqa + ) # noqa + print() # noqa try: if model: error_str = str(original_exception) @@ -4414,39 +5123,53 @@ def exception_type( exception_type = type(original_exception).__name__ else: exception_type = "" - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: + + if "Request Timeout Error" in error_str or "Request timed out" in error_str: exception_mapping_worked = True raise Timeout( message=f"APITimeoutError - Request timed out", model=model, - llm_provider=custom_llm_provider + llm_provider=custom_llm_provider, ) - if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai" or custom_llm_provider in litellm.openai_compatible_providers: - if "This model's maximum context length is" in error_str or "Request too large" in error_str: + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "custom_openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): + if ( + "This model's maximum context length is" in error_str + or "Request too large" in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "invalid_request_error" in error_str and "model_not_found" in error_str: + elif ( + "invalid_request_error" in error_str + and "model_not_found" in error_str + ): exception_mapping_worked = True raise NotFoundError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "invalid_request_error" in error_str and "Incorrect API key provided" not in error_str: + elif ( + "invalid_request_error" in error_str + and "Incorrect API key provided" not in error_str + ): exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -4456,7 +5179,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 404: exception_mapping_worked = True @@ -4464,7 +5187,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4479,7 +5202,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4487,17 +5210,17 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 503: + elif original_exception.status_code == 503: exception_mapping_worked = True raise 
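The OpenAI branch of `exception_type` being reformatted here maps provider status codes onto litellm's exception hierarchy. A condensed sketch of that dispatch idea, using stand-in exception classes rather than litellm's real ones:

```
# Condensed sketch of the status-code dispatch in exception_type.
class AuthenticationError(Exception): ...
class NotFoundError(Exception): ...
class Timeout(Exception): ...
class RateLimitError(Exception): ...
class ServiceUnavailableError(Exception): ...
class APIError(Exception): ...

STATUS_TO_EXC = {
    401: AuthenticationError,
    404: NotFoundError,
    408: Timeout,
    429: RateLimitError,
    503: ServiceUnavailableError,
    504: Timeout,
}

def map_openai_status(status_code: int, message: str) -> Exception:
    # Fall back to a generic APIError when no specific mapping applies.
    exc_cls = STATUS_TO_EXC.get(status_code, APIError)
    return exc_cls(f"OpenAIException - {message}")

print(type(map_openai_status(429, "rate limited")).__name__)
```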
ServiceUnavailableError( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 504: # gateway timeout error + elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( message=f"OpenAIException - {original_exception.message}", @@ -4507,11 +5230,11 @@ def exception_type( else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - request=original_exception.request + request=original_exception.request, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -4519,25 +5242,28 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider=custom_llm_provider, model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "anthropic": # one of the anthropics if hasattr(original_exception, "message"): - if "prompt is too long" in original_exception.message or "prompt: length" in original_exception.message: + if ( + "prompt is too long" in original_exception.message + or "prompt: length" in original_exception.message + ): exception_mapping_worked = True raise ContextWindowExceededError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) if "Invalid API Key" in original_exception.message: exception_mapping_worked = True raise AuthenticationError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): print_verbose(f"status_code: {original_exception.status_code}") @@ -4547,15 +5273,18 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 400 or original_exception.status_code == 413: + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 413 + ): exception_mapping_worked = True raise BadRequestError( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4563,7 +5292,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4571,7 +5300,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4579,7 +5308,7 @@ def exception_type( message=f"AnthropicException - 
{original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True @@ -4588,7 +5317,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "replicate": if "Incorrect authentication token" in error_str: @@ -4597,7 +5326,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif "input is too long" in error_str: exception_mapping_worked = True @@ -4605,7 +5334,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif exception_type == "ModelError": exception_mapping_worked = True @@ -4613,7 +5342,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif "Request was throttled" in error_str: exception_mapping_worked = True @@ -4621,7 +5350,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4630,15 +5359,19 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 400 or original_exception.status_code == 422 or original_exception.status_code == 413: + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 422 + or original_exception.status_code == 413 + ): exception_mapping_worked = True raise BadRequestError( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4646,7 +5379,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4654,7 +5387,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4662,48 +5395,60 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) exception_mapping_worked = True raise APIError( - status_code=500, + status_code=500, message=f"ReplicateException - {str(original_exception)}", llm_provider="replicate", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == 
"bedrock": - if "too many tokens" in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str: + if ( + "too many tokens" in error_str + or "expected maxLength:" in error_str + or "Input is too long" in error_str + or "Too many input tokens" in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"BedrockException: Context Window Error - {error_str}", - model=model, + model=model, llm_provider="bedrock", - response=original_exception.response + response=original_exception.response, ) if "Malformed input request" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, + message=f"BedrockException - {error_str}", + model=model, llm_provider="bedrock", - response=original_exception.response + response=original_exception.response, ) - if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str: + if ( + "Unable to locate credentials" in error_str + or "The security token included in the request is invalid" + in error_str + ): exception_mapping_worked = True raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response + message=f"BedrockException Invalid Authentication - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, ) - if "throttlingException" in error_str or "ThrottlingException" in error_str: + if ( + "throttlingException" in error_str + or "ThrottlingException" in error_str + ): exception_mapping_worked = True raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response + message=f"BedrockException: Rate Limit Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 500: @@ -4712,7 +5457,7 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 401: exception_mapping_worked = True @@ -4720,49 +5465,55 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response + response=original_exception.response, ) - elif custom_llm_provider == "sagemaker": + elif custom_llm_provider == "sagemaker": if "Unable to locate credentials" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, + message=f"SagemakerException - {error_str}", + model=model, llm_provider="sagemaker", - response=original_exception.response + response=original_exception.response, ) - elif "Input validation error: `best_of` must be > 0 and <= 2" in error_str: + elif ( + "Input validation error: `best_of` must be > 0 and <= 2" + in error_str + ): exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, llm_provider="sagemaker", - 
response=original_exception.response + response=original_exception.response, ) elif custom_llm_provider == "vertex_ai": - if "Vertex AI API has not been used in project" in error_str or "Unable to find your project" in error_str: + if ( + "Vertex AI API has not been used in project" in error_str + or "Unable to find your project" in error_str + ): exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) - elif "403" in error_str: + elif "403" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) elif "The response was blocked." in error_str: exception_mapping_worked = True raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -4771,16 +5522,16 @@ def exception_type( message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) - if original_exception.status_code == 500: + if original_exception.status_code == 500: exception_mapping_worked = True raise APIError( message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "palm": if "503 Getting metadata" in error_str: @@ -4788,10 +5539,10 @@ def exception_type( # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. 
exception_mapping_worked = True raise BadRequestError( - message=f"PalmException - Invalid api key", - model=model, + message=f"PalmException - Invalid api key", + model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) if "400 Request payload size exceeds" in error_str: exception_mapping_worked = True @@ -4799,7 +5550,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -4808,7 +5559,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes elif custom_llm_provider == "cohere": # Cohere @@ -4821,7 +5572,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "too many tokens" in error_str: exception_mapping_worked = True @@ -4829,16 +5580,19 @@ def exception_type( message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere", - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 400 or original_exception.status_code == 498: + if ( + original_exception.status_code == 400 + or original_exception.status_code == 498 + ): exception_mapping_worked = True raise BadRequestError( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4846,7 +5600,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif ( "CohereConnectionError" in exception_type @@ -4856,7 +5610,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "invalid type:" in error_str: exception_mapping_worked = True @@ -4864,7 +5618,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "Unexpected server error" in error_str: exception_mapping_worked = True @@ -4872,17 +5626,17 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) else: if hasattr(original_exception, "status_code"): exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - request=original_exception.request + request=original_exception.request, ) raise original_exception elif custom_llm_provider == "huggingface": @@ -4892,15 +5646,15 @@ def 
exception_type( message=error_str, model=model, llm_provider="huggingface", - response=original_exception.response + response=original_exception.response, ) elif "A valid user token is required" in error_str: exception_mapping_worked = True raise BadRequestError( - message=error_str, + message=error_str, llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4909,7 +5663,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -4917,7 +5671,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4925,7 +5679,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4933,16 +5687,16 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "ai21": if hasattr(original_exception, "message"): @@ -4952,15 +5706,15 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) - if "Bad or missing API token." in original_exception.message: + if "Bad or missing API token." 
in original_exception.message: exception_mapping_worked = True raise BadRequestError( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4969,7 +5723,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4977,7 +5731,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - request=original_exception.request + request=original_exception.request, ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -4985,7 +5739,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4993,16 +5747,16 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "nlp_cloud": if "detail" in error_str: @@ -5012,7 +5766,7 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) elif "value is not a valid" in error_str: exception_mapping_worked = True @@ -5020,148 +5774,188 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) - else: + else: exception_mapping_worked = True raise APIError( status_code=500, message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - request=original_exception.request + request=original_exception.request, ) - if hasattr(original_exception, "status_code"): # https://docs.nlpcloud.com/?shell#errors - if original_exception.status_code == 400 or original_exception.status_code == 406 or original_exception.status_code == 413 or original_exception.status_code == 422: + if hasattr( + original_exception, "status_code" + ): # https://docs.nlpcloud.com/?shell#errors + if ( + original_exception.status_code == 400 + or original_exception.status_code == 406 + or original_exception.status_code == 413 + or original_exception.status_code == 422 + ): exception_mapping_worked = True raise BadRequestError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 401 or original_exception.status_code == 403: + elif ( + original_exception.status_code == 401 + or original_exception.status_code == 403 + ): exception_mapping_worked = True raise AuthenticationError( 
message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 522 or original_exception.status_code == 524: + elif ( + original_exception.status_code == 522 + or original_exception.status_code == 524 + ): exception_mapping_worked = True raise Timeout( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - request=original_exception.request + request=original_exception.request, ) - elif original_exception.status_code == 429 or original_exception.status_code == 402: + elif ( + original_exception.status_code == 429 + or original_exception.status_code == 402 + ): exception_mapping_worked = True raise RateLimitError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 500 or original_exception.status_code == 503: + elif ( + original_exception.status_code == 500 + or original_exception.status_code == 503 + ): exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request + request=original_exception.request, ) - elif original_exception.status_code == 504 or original_exception.status_code == 520: + elif ( + original_exception.status_code == 504 + or original_exception.status_code == 520 + ): exception_mapping_worked = True raise ServiceUnavailableError( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "together_ai": import json + try: error_response = json.loads(error_str) except: error_response = {"error": error_str} - if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]: + if ( + "error" in error_response + and "`inputs` tokens + `max_new_tokens` must be <=" + in error_response["error"] + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - elif "error" in error_response and "invalid private key" in error_response["error"]: + elif ( + "error" in error_response + and "invalid private key" in error_response["error"] + ): exception_mapping_worked = True raise AuthenticationError( message=f"TogetherAIException - {error_response['error']}", llm_provider="together_ai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "error" in error_response and "INVALID_ARGUMENT" in error_response["error"]: + elif ( + "error" in error_response + and "INVALID_ARGUMENT" in error_response["error"] + ): exception_mapping_worked = True raise BadRequestError( 
message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - - elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]: + + elif ( + "error" in error_response + and "API key doesn't match expected format." + in error_response["error"] + ): exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - elif "error_type" in error_response and error_response["error_type"] == "validation": + elif ( + "error_type" in error_response + and error_response["error_type"] == "validation" + ): exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"TogetherAIException - {original_exception.message}", - model=model, - llm_provider="together_ai", - request=original_exception.request - ) + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + model=model, + llm_provider="together_ai", + request=original_exception.request, + ) elif original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"TogetherAIException - {original_exception.message}", - llm_provider="together_ai", - model=model, - response=original_exception.response - ) + exception_mapping_worked = True + raise RateLimitError( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) elif original_exception.status_code == 524: exception_mapping_worked = True raise Timeout( @@ -5169,31 +5963,34 @@ def exception_type( llm_provider="together_ai", model=model, ) - else: + else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"TogetherAIException - {original_exception.message}", llm_provider="together_ai", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "aleph_alpha": - if "This is longer than the model's maximum context length" in error_str: + if ( + "This is longer than the model's maximum context length" + in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif "InvalidToken" in error_str or "No token provided" in error_str: exception_mapping_worked = True raise BadRequestError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + 
llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): print_verbose(f"status code: {original_exception.status_code}") @@ -5202,7 +5999,7 @@ def exception_type( raise AuthenticationError( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", - model=model + model=model, ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -5210,7 +6007,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5218,7 +6015,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5226,30 +6023,30 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) raise original_exception raise original_exception elif custom_llm_provider == "ollama": if isinstance(original_exception, dict): error_str = original_exception.get("error", "") - else: + else: error_str = str(original_exception) if "no such file or directory" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", - model=model, - llm_provider="ollama", - response=original_exception.response - ) - elif "Failed to establish a new connection" in error_str: + message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", + model=model, + llm_provider="ollama", + response=original_exception.response, + ) + elif "Failed to establish a new connection" in error_str: exception_mapping_worked = True raise ServiceUnavailableError( message=f"OllamaException: {original_exception}", - llm_provider="ollama", + llm_provider="ollama", model=model, - response=original_exception.response + response=original_exception.response, ) elif "Invalid response object from API" in error_str: exception_mapping_worked = True @@ -5257,7 +6054,7 @@ def exception_type( message=f"OllamaException: {original_exception}", llm_provider="ollama", model=model, - response=original_exception.response + response=original_exception.response, ) elif custom_llm_provider == "vllm": if hasattr(original_exception, "status_code"): @@ -5267,16 +6064,16 @@ def exception_type( message=f"VLLMException - {original_exception.message}", llm_provider="vllm", model=model, - request=original_exception.request + request=original_exception.request, ) - elif custom_llm_provider == "azure": + elif custom_llm_provider == "azure": if "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True @@ -5284,7 +6081,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - 
response=original_exception.response + response=original_exception.response, ) elif "invalid_request_error" in error_str: exception_mapping_worked = True @@ -5292,7 +6089,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -5302,7 +6099,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5310,7 +6107,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - request=original_exception.request + request=original_exception.request, ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -5318,7 +6115,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5326,16 +6123,16 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - request=original_exception.request + request=original_exception.request, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -5343,31 +6140,36 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider="azure", model=model, - request=original_exception.request + request=original_exception.request, ) - if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk + if ( + "BadRequestError.__init__() missing 1 required positional argument: 'param'" + in str(original_exception) + ): # deal with edge-case invalid request error bug in openai-python sdk exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}", - model=model, + model=model, llm_provider=custom_llm_provider, - response=original_exception.response + response=original_exception.response, ) - else: # ensure generic errors always return APIConnectionError= + else: # ensure generic errors always return APIConnectionError= exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request=original_exception.request + request=original_exception.request, ) - else: - raise APIConnectionError( + else: + raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request= httpx.Request(method="POST", url="https://api.openai.com/v1/") # stub the request + request=httpx.Request( + method="POST", 
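The generic fallback that closes `exception_type` (its `url` argument continues just below) stubs an `httpx.Request` when the original exception carries none, so `APIConnectionError` consumers always get a request object. A minimal sketch of that stubbing step, assuming `httpx` is installed:

```
import httpx

def get_request_sketch(original_exception: Exception) -> httpx.Request:
    # Reuse the exception's request if present; otherwise stub one.
    if hasattr(original_exception, "request"):
        return original_exception.request  # type: ignore[attr-defined]
    return httpx.Request(method="POST", url="https://api.openai.com/v1/")

print(get_request_sketch(RuntimeError("boom")).url)
```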
url="https://api.openai.com/v1/" + ), # stub the request ) except Exception as e: # LOGGING @@ -5401,9 +6203,10 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None): executor.submit(litellm_telemetry, data) # threading.Thread(target=litellm_telemetry, args=(data,), daemon=True).start() + def get_or_generate_uuid(): temp_dir = os.path.join(os.path.abspath(os.sep), "tmp") - uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") + uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") try: # Try to open the file and load the UUID with open(uuid_file, "r") as file: @@ -5415,19 +6218,19 @@ def get_or_generate_uuid(): except FileNotFoundError: # Generate a new UUID if the file doesn't exist or is empty - try: + try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open(uuid_file, "w") as file: file.write(uuid_value) - except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt + except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open("litellm_uuid.txt", "w") as file: file.write(uuid_value) - except: # if this 3rd attempt fails just pass - # Good first issue for someone to improve this function :) + except: # if this 3rd attempt fails just pass + # Good first issue for someone to improve this function :) return except: # [Non-Blocking Error] @@ -5444,17 +6247,13 @@ def litellm_telemetry(data): uuid_value = str(uuid.uuid4()) try: # Prepare the data to send to litellm logging api - try: + try: pkg_version = importlib.metadata.version("litellm") except: pkg_version = None if "model" not in data: data["model"] = None - payload = { - "uuid": uuid_value, - "data": data, - "version:": pkg_version - } + payload = {"uuid": uuid_value, "data": data, "version:": pkg_version} # Make the POST request to litellm logging api response = requests.post( "https://litellm-logging.onrender.com/logging", @@ -5466,29 +6265,33 @@ def litellm_telemetry(data): # [Non-Blocking Error] return + ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str, default_value: Optional[str]=None): - if secret_name.startswith("os.environ/"): +def get_secret(secret_name: str, default_value: Optional[str] = None): + if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") - try: + try: if litellm.secret_manager_client is not None: try: client = litellm.secret_manager_client - if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + if ( + type(client).__module__ + "." 
+ type(client).__name__ + == "azure.keyvault.secrets._client.SecretClient" + ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient secret = retrieved_secret = client.get_secret(secret_name).value - else: # assume the default is infisicial client + else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ + except: # check if it's in os.environ secret = os.environ.get(secret_name) return secret else: return os.environ.get(secret_name) - except Exception as e: - if default_value is not None: + except Exception as e: + if default_value is not None: return default_value - else: + else: raise e @@ -5496,7 +6299,9 @@ def get_secret(secret_name: str, default_value: Optional[str]=None): # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere class CustomStreamWrapper: - def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None): + def __init__( + self, completion_stream, model, custom_llm_provider=None, logging_obj=None + ): self.model = model self.custom_llm_provider = custom_llm_provider self.logging_obj = logging_obj @@ -5504,7 +6309,7 @@ class CustomStreamWrapper: self.sent_first_chunk = False self.sent_last_chunk = False self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "", ""] - self.holding_chunk = "" + self.holding_chunk = "" self.complete_response = "" def __iter__(self): @@ -5513,94 +6318,113 @@ class CustomStreamWrapper: def __aiter__(self): return self - def process_chunk(self, chunk: str): + def process_chunk(self, chunk: str): """ NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta. """ - try: + try: chunk = chunk.strip() self.complete_response = self.complete_response.strip() - if chunk.startswith(self.complete_response): + if chunk.startswith(self.complete_response): # Remove last_sent_chunk only if it appears at the start of the new chunk - chunk = chunk[len(self.complete_response):] + chunk = chunk[len(self.complete_response) :] self.complete_response += chunk - return chunk - except Exception as e: + return chunk + except Exception as e: raise e - - def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): + + def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): hold = False - if finish_reason: - for token in self.special_tokens: + if finish_reason: + for token in self.special_tokens: if token in chunk: - chunk = chunk.replace(token, "") + chunk = chunk.replace(token, "") return hold, chunk - + if self.sent_first_chunk is True: return hold, chunk curr_chunk = self.holding_chunk + chunk curr_chunk = curr_chunk.strip() - for token in self.special_tokens: - if len(curr_chunk) < len(token) and curr_chunk in token: + for token in self.special_tokens: + if len(curr_chunk) < len(token) and curr_chunk in token: hold = True elif len(curr_chunk) >= len(token): if token in curr_chunk: self.holding_chunk = curr_chunk.replace(token, "") hold = True - else: + else: pass - - if hold is False: # reset - self.holding_chunk = "" - return hold, curr_chunk + if hold is False: # reset + self.holding_chunk = "" + return hold, curr_chunk def handle_anthropic_chunk(self, chunk): str_line = chunk.decode("utf-8") # Convert bytes to string - text = "" + text = "" is_finished = False finish_reason = None if str_line.startswith("data:"): data_json = json.loads(str_line[5:]) - text = data_json.get("completion", "") - if 
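`check_special_tokens` in `CustomStreamWrapper` above holds back a chunk that might be the prefix of a template token (e.g. `<|assistant|>`) until enough text arrives to decide whether to strip or emit it. A standalone sketch of that holding-buffer idea, simplified from the real method:

```
SPECIAL_TOKENS = ["<|assistant|>", "<|system|>", "<|user|>"]

class HoldingBuffer:
    def __init__(self):
        self.holding = ""

    def feed(self, chunk: str):
        curr = (self.holding + chunk).strip()
        for token in SPECIAL_TOKENS:
            if len(curr) < len(token) and curr in token:
                self.holding = curr  # possible prefix of a token: keep holding
                return None
            if token in curr:
                curr = curr.replace(token, "")  # strip the completed token
        self.holding = ""
        return curr

buf = HoldingBuffer()
print(buf.feed("<|assist"))    # None -> held back
print(buf.feed("ant|>Hello"))  # "Hello"
```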
data_json.get("stop_reason", None): + text = data_json.get("completion", "") + if data_json.get("stop_reason", None): is_finished = True finish_reason = data_json["stop_reason"] - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "error" in str_line: raise ValueError(f"Unable to parse response. Original response: {str_line}") else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_together_ai_chunk(self, chunk): chunk = chunk.decode("utf-8") - text = "" + text = "" is_finished = False finish_reason = None - if "text" in chunk: + if "text" in chunk: text_index = chunk.find('"text":"') # this checks if text: exists text_start = text_index + len('"text":"') text_end = chunk.find('"}', text_start) if text_index != -1 and text_end != -1: extracted_text = chunk[text_start:text_end] text = extracted_text - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "[DONE]" in chunk: return {"text": text, "is_finished": True, "finish_reason": "stop"} elif "error" in chunk: - raise litellm.together_ai.TogetherAIError(status_code=422, message=f"{str(chunk)}") + raise litellm.together_ai.TogetherAIError( + status_code=422, message=f"{str(chunk)}" + ) else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_huggingface_chunk(self, chunk): try: if type(chunk) != str: - chunk = chunk.decode("utf-8") # DO NOT REMOVE this: This is required for HF inference API + Streaming - text = "" + chunk = chunk.decode( + "utf-8" + ) # DO NOT REMOVE this: This is required for HF inference API + Streaming + text = "" is_finished = False finish_reason = "" print_verbose(f"chunk: {chunk}") @@ -5609,52 +6433,72 @@ class CustomStreamWrapper: print_verbose(f"data json: {data_json}") if "token" in data_json and "text" in data_json["token"]: text = data_json["token"]["text"] - if data_json.get("details", False) and data_json["details"].get("finish_reason", False): + if data_json.get("details", False) and data_json["details"].get( + "finish_reason", False + ): is_finished = True finish_reason = data_json["details"]["finish_reason"] - elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete - text = "" # don't return the final bos token + elif data_json.get( + "generated_text", False + ): # if full generated text exists, then stream is complete + text = "" # don't return the final bos token is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - elif "error" in chunk: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif "error" in chunk: raise ValueError(chunk) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - except Exception as e: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + except Exception as e: traceback.print_exc() # raise(e) - - def handle_ai21_chunk(self, chunk): # fake streaming + + def handle_ai21_chunk(self, chunk): # fake streaming chunk = 
chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["completions"][0]["data"]["text"] is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_maritalk_chunk(self, chunk): # fake streaming + + def handle_maritalk_chunk(self, chunk): # fake streaming chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["answer"] is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_nlp_cloud_chunk(self, chunk): - text = "" + text = "" is_finished = False finish_reason = "" try: if "dolphin" in self.model: chunk = self.process_chunk(chunk=chunk) - else: + else: data_json = json.loads(chunk) chunk = data_json["generated_text"] text = chunk @@ -5662,10 +6506,14 @@ class CustomStreamWrapper: text = text.replace("[DONE]", "") is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except Exception as e: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_aleph_alpha_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) @@ -5673,28 +6521,36 @@ class CustomStreamWrapper: text = data_json["completions"][0]["completion"] is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_cohere_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: - text = "" + text = "" is_finished = False finish_reason = "" - if "text" in data_json: + if "text" in data_json: text = data_json["text"] - elif "is_finished" in data_json: + elif "is_finished" in data_json: is_finished = data_json["is_finished"] finish_reason = data_json["finish_reason"] - else: + else: raise Exception(data_json) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. 
Original response: {chunk}") - + def handle_azure_chunk(self, chunk): is_finished = False finish_reason = "" @@ -5704,72 +6560,92 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif chunk.startswith("data:"): - data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): + data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): try: - if len(data_json["choices"]) > 0: - text = data_json["choices"][0]["delta"].get("content", "") - if data_json["choices"][0].get("finish_reason", None): + if len(data_json["choices"]) > 0: + text = data_json["choices"][0]["delta"].get("content", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + print_verbose( + f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" + ) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: - raise ValueError(f"Unable to parse response. Original response: {chunk}") + raise ValueError( + f"Unable to parse response. Original response: {chunk}" + ) elif "error" in chunk: raise ValueError(f"Unable to parse response. Original response: {chunk}") else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_replicate_chunk(self, chunk): try: - text = "" + text = "" is_finished = False finish_reason = "" - if "output" in chunk: - text = chunk['output'] - if "status" in chunk: + if "output" in chunk: + text = chunk["output"] + if "status" in chunk: if chunk["status"] == "succeeded": is_finished = True finish_reason = "stop" - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_openai_chat_completion_chunk(self, chunk): - try: + + def handle_openai_chat_completion_chunk(self, chunk): + try: print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n") str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None - original_chunk = None # this is used for function/tool calling - if len(str_line.choices) > 0: + original_chunk = None # this is used for function/tool calling + if len(str_line.choices) > 0: if str_line.choices[0].delta.content is not None: text = str_line.choices[0].delta.content - else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai + else: # function/tool calling chunk - when content is None. 
in this case we just return the original chunk from openai original_chunk = str_line if str_line.choices[0].finish_reason: is_finished = True finish_reason = str_line.choices[0].finish_reason return { - "text": text, - "is_finished": is_finished, + "text": text, + "is_finished": is_finished, "finish_reason": finish_reason, - "original_chunk": str_line + "original_chunk": str_line, } except Exception as e: traceback.print_exc() raise e def handle_openai_text_completion_chunk(self, chunk): - try: + try: str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None print_verbose(f"str_line: {str_line}") @@ -5777,20 +6653,36 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif str_line.startswith("data:"): data_json = json.loads(str_line[5:]) print_verbose(f"delta content: {data_json}") - text = data_json["choices"][0].get("text", "") - if data_json["choices"][0].get("finish_reason", None): + text = data_json["choices"][0].get("text", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + print_verbose( + f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" + ) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "error" in str_line: - raise ValueError(f"Unable to parse response. Original response: {str_line}") + raise ValueError( + f"Unable to parse response. Original response: {str_line}" + ) else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except Exception as e: traceback.print_exc() @@ -5808,14 +6700,22 @@ class CustomStreamWrapper: return "" data_json = json.loads(chunk) if "model_output" in data_json: - if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list): + if ( + isinstance(data_json["model_output"], dict) + and "data" in data_json["model_output"] + and isinstance(data_json["model_output"]["data"], list) + ): return data_json["model_output"]["data"][0] elif isinstance(data_json["model_output"], str): return data_json["model_output"] - elif "completion" in data_json and isinstance(data_json["completion"], str): + elif "completion" in data_json and isinstance( + data_json["completion"], str + ): return data_json["completion"] else: - raise ValueError(f"Unable to parse response. Original response: {chunk}") + raise ValueError( + f"Unable to parse response. 
Original response: {chunk}" + ) else: return "" else: @@ -5824,53 +6724,60 @@ class CustomStreamWrapper: traceback.print_exc() return "" - def handle_ollama_stream(self, chunk): - try: + def handle_ollama_stream(self, chunk): + try: if isinstance(chunk, dict): json_chunk = chunk else: json_chunk = json.loads(chunk) - if "error" in json_chunk: + if "error" in json_chunk: raise Exception(f"Ollama Error - {json_chunk}") - - text = "" + + text = "" is_finished = False finish_reason = None if json_chunk["done"] == True: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif json_chunk["response"]: print_verbose(f"delta content: {json_chunk}") text = json_chunk["response"] - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - else: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + else: raise Exception(f"Ollama Error - {json_chunk}") - except Exception as e: + except Exception as e: raise e - def handle_bedrock_stream(self, chunk): if hasattr(chunk, "get"): - chunk = chunk.get('chunk') - chunk_data = json.loads(chunk.get('bytes').decode()) + chunk = chunk.get("chunk") + chunk_data = json.loads(chunk.get("bytes").decode()) else: chunk_data = json.loads(chunk.decode()) if chunk_data: - text = "" + text = "" is_finished = False finish_reason = "" - if "outputText" in chunk_data: - text = chunk_data['outputText'] + if "outputText" in chunk_data: + text = chunk_data["outputText"] # ai21 mapping - if "ai21" in self.model: # fake ai21 streaming - text = chunk_data.get('completions')[0].get('data').get('text') + if "ai21" in self.model: # fake ai21 streaming + text = chunk_data.get("completions")[0].get("data").get("text") is_finished = True finish_reason = "stop" # anthropic mapping - elif "completion" in chunk_data: - text = chunk_data['completion'] # bedrock.anthropic + elif "completion" in chunk_data: + text = chunk_data["completion"] # bedrock.anthropic stop_reason = chunk_data.get("stop_reason", None) if stop_reason != None: is_finished = True @@ -5878,22 +6785,26 @@ class CustomStreamWrapper: ######## bedrock.cohere mappings ############### # meta mapping elif "generation" in chunk_data: - text = chunk_data['generation'] # bedrock.meta + text = chunk_data["generation"] # bedrock.meta # cohere mapping elif "text" in chunk_data: - text = chunk_data["text"] # bedrock.cohere + text = chunk_data["text"] # bedrock.cohere # cohere mapping for finish reason elif "finish_reason" in chunk_data: finish_reason = chunk_data["finish_reason"] is_finished = True - elif chunk_data.get("completionReason", None): + elif chunk_data.get("completionReason", None): is_finished = True finish_reason = chunk_data["completionReason"] - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } return "" - + def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) model_response.choices = [StreamingChoices()] @@ -5905,62 +6816,83 @@ class CustomStreamWrapper: if self.custom_llm_provider and self.custom_llm_provider == "anthropic": response_obj = self.handle_anthropic_chunk(chunk) completion_obj["content"] = 
response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - elif ( - self.custom_llm_provider and self.custom_llm_provider == "together_ai"): + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": response_obj = self.handle_together_ai_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": response_obj = self.handle_huggingface_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + elif ( + self.custom_llm_provider and self.custom_llm_provider == "baseten" + ): # baseten doesn't provide streaming completion_obj["content"] = self.handle_baseten_chunk(chunk) - elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming + elif ( + self.custom_llm_provider and self.custom_llm_provider == "ai21" + ): # ai21 doesn't provide streaming response_obj = self.handle_ai21_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": response_obj = self.handle_maritalk_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "vllm": completion_obj["content"] = chunk[0].outputs[0].text - elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming + elif ( + self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha" + ): # aleph alpha doesn't provide streaming response_obj = self.handle_aleph_alpha_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "nlp_cloud": - try: + try: response_obj = self.handle_nlp_cloud_chunk(chunk) 
completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] except Exception as e: if self.sent_last_chunk: raise e else: - if self.sent_first_chunk is False: + if self.sent_first_chunk is False: raise Exception("An unknown error occurred with the stream") model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai": try: # print(chunk) - if hasattr(chunk, 'text'): - # vertexAI chunks return + if hasattr(chunk, "text"): + # vertexAI chunks return # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python # This Python code says "Hi" 100 times. # Create]) @@ -5968,28 +6900,32 @@ class CustomStreamWrapper: else: completion_obj["content"] = str(chunk) except StopIteration as e: - if self.sent_last_chunk: - raise e + if self.sent_last_chunk: + raise e else: model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider == "cohere": response_obj = self.handle_cohere_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "bedrock": - if self.sent_last_chunk: + if self.sent_last_chunk: raise StopIteration response_obj = self.handle_bedrock_stream(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": print_verbose(f"ENTERS SAGEMAKER STREAMING") - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if len(self.completion_stream) == 0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -5997,10 +6933,12 @@ class CustomStreamWrapper: new_chunk = self.completion_stream print_verbose(f"sagemaker chunk: {new_chunk}") completion_obj["content"] = new_chunk - self.completion_stream = self.completion_stream[len(self.completion_stream):] + self.completion_stream = self.completion_stream[ + len(self.completion_stream) : + ] elif self.custom_llm_provider == "petals": - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if len(self.completion_stream) == 0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -6013,8 +6951,8 @@ class CustomStreamWrapper: elif self.custom_llm_provider == "palm": # fake streaming response_obj = {} - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if 
len(self.completion_stream) == 0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -6028,33 +6966,50 @@ class CustomStreamWrapper: response_obj = self.handle_ollama_stream(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "text-completion-openai": response_obj = self.handle_openai_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - else: # openai chat model + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + else: # openai chat model response_obj = self.handle_openai_chat_completion_chunk(chunk) if response_obj == None: return completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] model_response.model = self.model - print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}") - print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}") - if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string - hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason) # filter out bos/eos tokens from openai-compatible hf endpoints - print_verbose(f"hold - {hold}, model_response_str - {model_response_str}") - if hold is False: - ## check if openai/azure chunk + print_verbose( + f"model_response: {model_response}; completion_obj: {completion_obj}" + ) + print_verbose( + f"model_response finish reason 3: {model_response.choices[0].finish_reason}" + ) + if ( + len(completion_obj["content"]) > 0 + ): # cannot set content of an OpenAI Object to be an empty string + hold, model_response_str = self.check_special_tokens( + chunk=completion_obj["content"], + finish_reason=model_response.choices[0].finish_reason, + ) # filter out bos/eos tokens from openai-compatible hf endpoints + print_verbose( + f"hold - {hold}, model_response_str - {model_response_str}" + ) + if hold is False: + ## check if openai/azure chunk original_chunk = response_obj.get("original_chunk", None) - if original_chunk: + if original_chunk: model_response.id = original_chunk.id if len(original_chunk.choices) > 0: try: @@ -6062,122 +7017,158 @@ class CustomStreamWrapper: model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: - return - model_response.system_fingerprint = original_chunk.system_fingerprint + else: + return + model_response.system_fingerprint = ( + original_chunk.system_fingerprint + ) if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True - else: - ## else - completion_obj["content"] = 
model_response_str + else: + ## else + completion_obj["content"] = model_response_str if self.sent_first_chunk == False: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) print_verbose(f"model_response: {model_response}") return model_response - else: - return + else: + return elif model_response.choices[0].finish_reason: - # flush any remaining holding chunk + # flush any remaining holding chunk if len(self.holding_chunk) > 0: if model_response.choices[0].delta.content is None: model_response.choices[0].delta.content = self.holding_chunk else: - model_response.choices[0].delta.content = self.holding_chunk + model_response.choices[0].delta.content - self.holding_chunk = "" - model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai + model_response.choices[0].delta.content = ( + self.holding_chunk + model_response.choices[0].delta.content + ) + self.holding_chunk = "" + model_response.choices[0].finish_reason = map_finish_reason( + model_response.choices[0].finish_reason + ) # ensure consistent output to openai return model_response - elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints + elif ( + response_obj is not None + and response_obj.get("original_chunk", None) is not None + ): # function / tool calling branch - only set for openai/azure compatible endpoints # enter this branch when no content has been passed in response original_chunk = response_obj.get("original_chunk", None) model_response.id = original_chunk.id if len(original_chunk.choices) > 0: - if original_chunk.choices[0].delta.function_call is not None or original_chunk.choices[0].delta.tool_calls is not None: + if ( + original_chunk.choices[0].delta.function_call is not None + or original_chunk.choices[0].delta.tool_calls is not None + ): try: delta = dict(original_chunk.choices[0].delta) model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: + else: return - else: + else: return model_response.system_fingerprint = original_chunk.system_fingerprint if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True return model_response - else: + else: return except StopIteration: raise StopIteration - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() e.message = str(e) - raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e) + raise exception_type( + model=self.model, + custom_llm_provider=self.custom_llm_provider, + original_exception=e, + ) ## needs to handle the empty string case (even starting chunk can be an empty string) def __next__(self): try: while True: - if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes): + if isinstance(self.completion_stream, str) or isinstance( + self.completion_stream, bytes + ): chunk = self.completion_stream else: chunk = next(self.completion_stream) print_verbose(f"value of chunk: {chunk} ") - if chunk is not None and chunk != b'': + if chunk is not None and chunk != b"": print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") response = self.chunk_creator(chunk=chunk) print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}") - if response is None: + if response is 
None: continue ## LOGGING - threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response + threading.Thread( + target=self.logging_obj.success_handler, args=(response,) + ).start() # log response return response except StopIteration: raise # Re-raise StopIteration except Exception as e: traceback_exception = traceback.format_exc() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated - threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start() + threading.Thread( + target=self.logging_obj.failure_handler, args=(e, traceback_exception) + ).start() raise e - - async def __anext__(self): try: - if (self.custom_llm_provider == "openai" + if ( + self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure" or self.custom_llm_provider == "custom_openai" or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "huggingface" or self.custom_llm_provider == "ollama" - or self.custom_llm_provider == "vertex_ai"): + or self.custom_llm_provider == "vertex_ai" + ): print_verbose(f"INSIDE ASYNC STREAMING!!!") - print_verbose(f"value of async completion stream: {self.completion_stream}") + print_verbose( + f"value of async completion stream: {self.completion_stream}" + ) async for chunk in self.completion_stream: print_verbose(f"value of async chunk: {chunk}") if chunk == "None" or chunk is None: raise Exception - # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. + # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. 
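                # The same chunk_creator() serves the sync __next__ path; a caller just iterates, e.g.
                #   async for part in CustomStreamWrapper(stream, model, custom_llm_provider="openai", logging_obj=logging_obj):
                #       print(part.choices[0].delta.content)
                # (illustrative sketch only; argument names follow the __init__ signature above)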
# __anext__ also calls async_success_handler, which does logging print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") - processed_chunk = self.chunk_creator(chunk=chunk) - print_verbose(f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}") - if processed_chunk is None: + processed_chunk = self.chunk_creator(chunk=chunk) + print_verbose( + f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: continue ## LOGGING - threading.Thread(target=self.logging_obj.success_handler, args=(processed_chunk,)).start() # log response - asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) + threading.Thread( + target=self.logging_obj.success_handler, args=(processed_chunk,) + ).start() # log response + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) return processed_chunk raise StopAsyncIteration - else: # temporary patch for non-aiohttp async calls + else: # temporary patch for non-aiohttp async calls # example - boto3 bedrock llms processed_chunk = next(self) - asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) return processed_chunk except StopAsyncIteration: raise @@ -6186,9 +7177,12 @@ class CustomStreamWrapper: except Exception as e: traceback_exception = traceback.format_exc() # Handle any exceptions that might occur during streaming - asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception)) + asyncio.create_task( + self.logging_obj.async_failure_handler(e, traceback_exception) + ) raise e + class TextCompletionStreamWrapper: def __init__(self, completion_stream, model): self.completion_stream = completion_stream @@ -6199,16 +7193,18 @@ class TextCompletionStreamWrapper: def __aiter__(self): return self - + def convert_to_text_completion_object(self, chunk: ModelResponse): - try: + try: response = TextCompletionResponse() response["id"] = chunk.get("id", None) response["object"] = "text_completion" response["created"] = response.get("created", None) response["model"] = response.get("model", None) text_choices = TextChoices() - if isinstance(chunk, Choices): # chunk should always be of type StreamingChoices + if isinstance( + chunk, Choices + ): # chunk should always be of type StreamingChoices raise Exception text_choices["text"] = chunk["choices"][0]["delta"]["content"] text_choices["index"] = response["choices"][0]["index"] @@ -6216,7 +7212,9 @@ class TextCompletionStreamWrapper: response["choices"] = [text_choices] return response except Exception as e: - raise Exception(f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}") + raise Exception( + f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}" + ) def __next__(self): # model_response = ModelResponse(stream=True, model=self.model) @@ -6224,32 +7222,34 @@ class TextCompletionStreamWrapper: try: for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopIteration - except Exception as e: - print(f"got exception {e}") # noqa + except Exception as e: + print(f"got exception {e}") # noqa async def __anext__(self): try: async for chunk in 
self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopAsyncIteration + def mock_completion_streaming_obj(model_response, mock_response, model): for i in range(0, len(mock_response), 3): - completion_obj = {"role": "assistant", "content": mock_response[i: i+3]} + completion_obj = {"role": "assistant", "content": mock_response[i : i + 3]} model_response.choices[0].delta = completion_obj yield model_response + ########## Reading Config File ############################ def read_config_args(config_path) -> dict: try: @@ -6264,23 +7264,25 @@ def read_config_args(config_path) -> dict: except Exception as e: raise e + ########## experimental completion variants ############################ + def completion_with_config(config: Union[dict, str], **kwargs): """ - Generate a litellm.completion() using a config dict and all supported completion args + Generate a litellm.completion() using a config dict and all supported completion args Example config; config = { "default_fallback_models": # [Optional] List of model names to try if a call fails - "available_models": # [Optional] List of all possible models you could call + "available_models": # [Optional] List of all possible models you could call "adapt_to_prompt_size": # [Optional] True/False - if you want to select model based on prompt size (will pick from available_models) "model": { "model-name": { - "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. + "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. 
"error_handling": { "error-type": { # One of the errors listed here - https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list - "fallback_model": "" # str, name of the model it should try instead, when that error occurs + "fallback_model": "" # str, name of the model it should try instead, when that error occurs } } } @@ -6304,11 +7306,11 @@ def completion_with_config(config: Union[dict, str], **kwargs): raise Exception("Config path must be a string or a dictionary.") else: raise Exception("Config path not passed in.") - + if config is None: raise Exception("No completion config in the config file") - - models_with_config = config["model"].keys() + + models_with_config = config["model"].keys() model = kwargs["model"] messages = kwargs["messages"] @@ -6319,13 +7321,16 @@ def completion_with_config(config: Union[dict, str], **kwargs): trim_messages_flag = config.get("trim_messages", False) prompt_larger_than_model = False max_model = model - try: + try: max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: - max_tokens = 2048 # assume curr model's max window is 2048 tokens + max_tokens = 2048 # assume curr model's max window is 2048 tokens if adapt_to_prompt_size: - ## Pick model based on token window - prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages)) + ## Pick model based on token window + prompt_tokens = litellm.token_counter( + model="gpt-3.5-turbo", + text="".join(message["content"] for message in messages), + ) try: curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: @@ -6334,7 +7339,9 @@ def completion_with_config(config: Union[dict, str], **kwargs): prompt_larger_than_model = True for available_model in available_models: try: - curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"] + curr_max_tokens = litellm.get_max_tokens(available_model)[ + "max_tokens" + ] if curr_max_tokens > max_tokens: max_tokens = curr_max_tokens max_model = available_model @@ -6348,16 +7355,16 @@ def completion_with_config(config: Union[dict, str], **kwargs): kwargs["messages"] = messages kwargs["model"] = model - try: - if model in models_with_config: + try: + if model in models_with_config: ## Moderation check if config["model"][model].get("needs_moderation"): input = " ".join(message["content"] for message in messages) response = litellm.moderation(input=input) flagged = response["results"][0]["flagged"] - if flagged: + if flagged: raise Exception("This response was flagged as inappropriate") - + ## Model-specific Error Handling error_handling = None if config["model"][model].get("error_handling"): @@ -6369,22 +7376,25 @@ def completion_with_config(config: Union[dict, str], **kwargs): except Exception as e: exception_name = type(e).__name__ fallback_model = None - if error_handling and exception_name in error_handling: + if error_handling and exception_name in error_handling: error_handler = error_handling[exception_name] - # either switch model or api key + # either switch model or api key fallback_model = error_handler.get("fallback_model", None) - if fallback_model: + if fallback_model: kwargs["model"] = fallback_model return litellm.completion(**kwargs) raise e - else: + else: return litellm.completion(**kwargs) except Exception as e: if fallback_models: model = fallback_models.pop(0) - return completion_with_fallbacks(model=model, messages=messages, fallbacks=fallback_models) + return completion_with_fallbacks( + model=model, messages=messages, fallbacks=fallback_models + ) 
raise e + def completion_with_fallbacks(**kwargs): nested_kwargs = kwargs.pop("kwargs", {}) response = None @@ -6402,8 +7412,10 @@ def completion_with_fallbacks(**kwargs): for model in fallbacks: # loop thru all models try: - # check if it's dict or new model string - if isinstance(model, dict): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) + # check if it's dict or new model string + if isinstance( + model, dict + ): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) kwargs["api_key"] = model.get("api_key", None) kwargs["api_base"] = model.get("api_base", None) model = model.get("model", original_model) @@ -6426,7 +7438,10 @@ def completion_with_fallbacks(**kwargs): print_verbose(f"trying to make completion call with model: {model}") kwargs["litellm_call_id"] = litellm_call_id - kwargs = {**kwargs, **nested_kwargs} # combine the openai + litellm params at the same level + kwargs = { + **kwargs, + **nested_kwargs, + } # combine the openai + litellm params at the same level response = litellm.completion(**kwargs, model=model) print_verbose(f"response: {response}") if response != None: @@ -6441,18 +7456,24 @@ def completion_with_fallbacks(**kwargs): pass return response + def process_system_message(system_message, max_tokens, model): system_message_event = {"role": "system", "content": system_message} system_message_tokens = get_token_count([system_message_event], model) if system_message_tokens > max_tokens: - print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...") + print_verbose( + "`tokentrimmer`: Warning, system message exceeds token limit. Trimming..." + ) # shorten system message to fit within max_tokens - new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model) + new_system_message = shorten_message_to_fit_limit( + system_message_event, max_tokens, model + ) system_message_tokens = get_token_count([new_system_message], model) - + return system_message_event, max_tokens - system_message_tokens + def process_messages(messages, max_tokens, model): # Process messages from older to more recent messages = messages[::-1] @@ -6463,17 +7484,26 @@ def process_messages(messages, max_tokens, model): available_tokens = max_tokens - used_tokens if available_tokens <= 3: break - final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model) + final_messages = attempt_message_addition( + final_messages=final_messages, + message=message, + available_tokens=available_tokens, + max_tokens=max_tokens, + model=model, + ) return final_messages -def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model): + +def attempt_message_addition( + final_messages, message, available_tokens, max_tokens, model +): temp_messages = [message] + final_messages temp_message_tokens = get_token_count(messages=temp_messages, model=model) if temp_message_tokens <= max_tokens: return temp_messages - + # if temp_message_tokens > max_tokens, try shortening temp_messages elif "function_call" not in message: # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens) @@ -6483,19 +7513,18 @@ def attempt_message_addition(final_messages, message, available_tokens, max_toke return final_messages + def can_add_message(message, messages, max_tokens, model): if 
get_token_count(messages + [message], model) <= max_tokens: return True return False + def get_token_count(messages, model): return token_counter(model=model, messages=messages) -def shorten_message_to_fit_limit( - message, - tokens_needed, - model): +def shorten_message_to_fit_limit(message, tokens_needed, model): """ Shorten a message to fit within a token limit by removing characters from the middle. """ @@ -6503,7 +7532,7 @@ def shorten_message_to_fit_limit( # For OpenAI models, even blank messages cost 7 token, # and if the buffer is less than 3, the while loop will never end, # hence the value 10. - if 'gpt' in model and tokens_needed <= 10: + if "gpt" in model and tokens_needed <= 10: return message content = message["content"] @@ -6515,21 +7544,22 @@ def shorten_message_to_fit_limit( break ratio = (tokens_needed) / total_tokens - - new_length = int(len(content) * ratio) -1 + + new_length = int(len(content) * ratio) - 1 new_length = max(0, new_length) half_length = new_length // 2 left_half = content[:half_length] right_half = content[-half_length:] - trimmed_content = left_half + '..' + right_half + trimmed_content = left_half + ".." + right_half message["content"] = trimmed_content content = trimmed_content return message -# LiteLLM token trimmer + +# LiteLLM token trimmer # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # Credits for this code go to Killian Lucas def trim_messages( @@ -6537,8 +7567,8 @@ def trim_messages( model: Optional[str] = None, trim_ratio: float = 0.75, return_response_tokens: bool = False, - max_tokens = None - ): + max_tokens=None, +): """ Trim a list of messages to fit within a model's token limit. @@ -6560,18 +7590,18 @@ def trim_messages( if max_tokens == None: # Check if model is valid if model in litellm.model_cost: - max_tokens_for_model = litellm.model_cost[model]['max_tokens'] + max_tokens_for_model = litellm.model_cost[model]["max_tokens"] max_tokens = int(max_tokens_for_model * trim_ratio) else: - # if user did not specify max tokens + # if user did not specify max tokens # or passed an llm litellm does not know # do nothing, just return messages - return - - system_message = "" + return + + system_message = "" for message in messages: if message["role"] == "system": - system_message += '\n' if system_message else '' + system_message += "\n" if system_message else "" system_message += message["content"] current_tokens = token_counter(model=model, messages=messages) @@ -6579,38 +7609,47 @@ def trim_messages( # Do nothing if current tokens under messages if current_tokens < max_tokens: - return messages - - #### Trimming messages if current_tokens > max_tokens - print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}") - if system_message: - system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) + return messages - if max_tokens == 0: # the system messages are too long + #### Trimming messages if current_tokens > max_tokens + print_verbose( + f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}" + ) + if system_message: + system_message_event, max_tokens = process_system_message( + system_message=system_message, max_tokens=max_tokens, model=model + ) + + if max_tokens == 0: # the system messages are too long return [system_message_event] - - # Since all system messages are combined and trimmed to fit the max_tokens, + + # 
Since all system messages are combined and trimmed to fit the max_tokens, # we remove all system messages from the messages list messages = [message for message in messages if message["role"] != "system"] - final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model) + final_messages = process_messages( + messages=messages, max_tokens=max_tokens, model=model + ) # Add system message to the beginning of the final messages if system_message: final_messages = [system_message_event] + final_messages - if return_response_tokens: # if user wants token count with new trimmed messages + if ( + return_response_tokens + ): # if user wants token count with new trimmed messages response_tokens = max_tokens - get_token_count(final_messages, model) return final_messages, response_tokens return final_messages - except Exception as e: # [NON-Blocking, if error occurs just return final_messages + except Exception as e: # [NON-Blocking, if error occurs just return final_messages print_verbose(f"Got exception while token trimming{e}") return messages + def get_valid_models(): """ Returns a list of valid LLMs based on the set environment variables - + Args: None @@ -6628,13 +7667,13 @@ def get_valid_models(): # edge case litellm has together_ai as a provider, it should be togetherai provider = provider.replace("_", "") - # litellm standardizes expected provider keys to + # litellm standardizes expected provider keys to # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY expected_provider_key = f"{provider.upper()}_API_KEY" - if expected_provider_key in environ_keys: - # key is set + if expected_provider_key in environ_keys: + # key is set valid_providers.append(provider) - + for provider in valid_providers: if provider == "azure": valid_models.append("Azure-LLM") @@ -6643,7 +7682,8 @@ def get_valid_models(): valid_models.extend(models_for_provider) return valid_models except: - return [] # NON-Blocking + return [] # NON-Blocking + # used for litellm.text_completion() to transform HF logprobs to OpenAI.Completion() format def transform_logprobs(hf_response): @@ -6653,40 +7693,39 @@ def transform_logprobs(hf_response): # For each Hugging Face response, transform the logprobs for response in hf_response: # Extract the relevant information from the response - response_details = response['details'] + response_details = response["details"] top_tokens = response_details.get("top_tokens", {}) # Initialize an empty list for the token information token_info = { - 'tokens': [], - 'token_logprobs': [], - 'text_offset': [], - 'top_logprobs': [], + "tokens": [], + "token_logprobs": [], + "text_offset": [], + "top_logprobs": [], } - for i, token in enumerate(response_details['prefill']): + for i, token in enumerate(response_details["prefill"]): # Extract the text of the token - token_text = token['text'] + token_text = token["text"] # Extract the logprob of the token - token_logprob = token['logprob'] + token_logprob = token["logprob"] # Add the token information to the 'token_info' list - token_info['tokens'].append(token_text) - token_info['token_logprobs'].append(token_logprob) + token_info["tokens"].append(token_text) + token_info["token_logprobs"].append(token_logprob) # stub this to work with llm eval harness - top_alt_tokens = { "": -1, "": -2, "": -3 } - token_info['top_logprobs'].append(top_alt_tokens) + top_alt_tokens = {"": -1, "": -2, "": -3} + token_info["top_logprobs"].append(top_alt_tokens) # For each element in the 'tokens' list, extract the relevant information - for i, 
token in enumerate(response_details['tokens']): - + for i, token in enumerate(response_details["tokens"]): # Extract the text of the token - token_text = token['text'] + token_text = token["text"] # Extract the logprob of the token - token_logprob = token['logprob'] + token_logprob = token["logprob"] top_alt_tokens = {} temp_top_logprobs = [] @@ -6700,13 +7739,15 @@ def transform_logprobs(hf_response): top_alt_tokens[text] = logprob # Add the token information to the 'token_info' list - token_info['tokens'].append(token_text) - token_info['token_logprobs'].append(token_logprob) - token_info['top_logprobs'].append(top_alt_tokens) + token_info["tokens"].append(token_text) + token_info["token_logprobs"].append(token_logprob) + token_info["top_logprobs"].append(top_alt_tokens) # Add the text offset of the token # This is computed as the sum of the lengths of all previous tokens - token_info['text_offset'].append(sum(len(t['text']) for t in response_details['tokens'][:i])) + token_info["text_offset"].append( + sum(len(t["text"]) for t in response_details["tokens"][:i]) + ) # Add the 'token_info' list to the 'transformed_logprobs' list transformed_logprobs = token_info diff --git a/ui/admin.py b/ui/admin.py index 438f0f0a9..4a58080fa 100644 --- a/ui/admin.py +++ b/ui/admin.py @@ -2,6 +2,7 @@ Admin sets proxy url + allowed email subdomain """ from dotenv import load_dotenv + load_dotenv() import streamlit as st import base64, os @@ -9,34 +10,47 @@ import base64, os # Replace your_base_url with the actual URL where the proxy auth app is hosted your_base_url = os.getenv("BASE_URL") # Example base URL + # Function to encode the configuration def encode_config(proxy_url, allowed_email_subdomain): - combined_string = f"proxy_url={proxy_url}&accepted_email_subdomain={allowed_email_subdomain}" - return base64.b64encode(combined_string.encode('utf-8')).decode('utf-8') + combined_string = ( + f"proxy_url={proxy_url}&accepted_email_subdomain={allowed_email_subdomain}" + ) + return base64.b64encode(combined_string.encode("utf-8")).decode("utf-8") + # Simple function to update config values def update_config(proxy_url, allowed_email_subdomain): - st.session_state['proxy_url'] = proxy_url - st.session_state['allowed_email_subdomain'] = allowed_email_subdomain - st.session_state['user_auth_url'] = f"{your_base_url}/?page={encode_config(proxy_url=proxy_url, allowed_email_subdomain=allowed_email_subdomain)}" + st.session_state["proxy_url"] = proxy_url + st.session_state["allowed_email_subdomain"] = allowed_email_subdomain + st.session_state[ + "user_auth_url" + ] = f"{your_base_url}/?page={encode_config(proxy_url=proxy_url, allowed_email_subdomain=allowed_email_subdomain)}" + def admin_page(): # Display the form for the admin to set the proxy URL and allowed email subdomain st.header("Admin Configuration") # Create a configuration placeholder - st.session_state.setdefault('proxy_url', 'http://example.com') - st.session_state.setdefault('allowed_email_subdomain', 'example.com') - st.session_state.setdefault('user_auth_url', 'NOT_GIVEN') + st.session_state.setdefault("proxy_url", "http://example.com") + st.session_state.setdefault("allowed_email_subdomain", "example.com") + st.session_state.setdefault("user_auth_url", "NOT_GIVEN") with st.form("config_form", clear_on_submit=False): - proxy_url = st.text_input("Set Proxy URL", st.session_state['proxy_url']) - allowed_email_subdomain = st.text_input("Set Allowed Email Subdomain", st.session_state['allowed_email_subdomain']) + proxy_url = st.text_input("Set Proxy 
URL", st.session_state["proxy_url"]) + allowed_email_subdomain = st.text_input( + "Set Allowed Email Subdomain", st.session_state["allowed_email_subdomain"] + ) submitted = st.form_submit_button("Save") if submitted: - update_config(proxy_url=proxy_url, allowed_email_subdomain=allowed_email_subdomain) + update_config( + proxy_url=proxy_url, allowed_email_subdomain=allowed_email_subdomain + ) # Display the current configuration st.write(f"Current Proxy URL: {st.session_state['proxy_url']}") - st.write(f"Current Allowed Email Subdomain: {st.session_state['allowed_email_subdomain']}") - st.write(f"Current User Auth URL: {st.session_state['user_auth_url']}") \ No newline at end of file + st.write( + f"Current Allowed Email Subdomain: {st.session_state['allowed_email_subdomain']}" + ) + st.write(f"Current User Auth URL: {st.session_state['user_auth_url']}") diff --git a/ui/app.py b/ui/app.py index 5d49d45f8..ffb91e01e 100644 --- a/ui/app.py +++ b/ui/app.py @@ -2,6 +2,7 @@ Routes between admin, auth, keys pages """ from dotenv import load_dotenv + load_dotenv() import streamlit as st import base64, binascii, os @@ -9,18 +10,20 @@ from admin import admin_page from auth import auth_page from urllib.parse import urlparse, parse_qs + # Parse the query params in the URL def get_query_params(): # Get the query params from Streamlit's `server.request` function # This functionality is not officially documented and could change in the future versions of Streamlit - query_params = st.experimental_get_query_params() - return query_params + query_params = st.experimental_get_query_params() + return query_params + def is_base64(sb): try: if isinstance(sb, str): # Try to encode it to bytes if it's a unicode string - sb_bytes = sb.encode('ascii') + sb_bytes = sb.encode("ascii") elif isinstance(sb, bytes): sb_bytes = sb else: @@ -36,10 +39,11 @@ def is_base64(sb): except (binascii.Error, ValueError): # If an error occurs, return False, as the input is not base64 return False - + + # Check the URL path and route to the correct page based on the path query_params = get_query_params() -page_param = query_params.get('page', [None])[0] +page_param = query_params.get("page", [None])[0] # Route to the appropriate page based on the URL query param if page_param: diff --git a/ui/auth.py b/ui/auth.py index d94e30ffb..cbcfabe65 100644 --- a/ui/auth.py +++ b/ui/auth.py @@ -6,9 +6,11 @@ Uses supabase passwordless auth: https://supabase.com/docs/reference/python/auth Remember to set your redirect url to 8501 (streamlit default). """ import logging + logging.basicConfig(level=logging.DEBUG) import streamlit as st from dotenv import load_dotenv + load_dotenv() import os from supabase import create_client, Client @@ -20,14 +22,13 @@ supabase: Client = create_client(url, key) def sign_in_with_otp(email: str, redirect_url: str): - data = supabase.auth.sign_in_with_otp({"email": email, - "options": { - "email_redirect_to": redirect_url - }}) + data = supabase.auth.sign_in_with_otp( + {"email": email, "options": {"email_redirect_to": redirect_url}} + ) print(f"data: {data}") # Redirect to Supabase UI with the return data st.write(f"Please check your email for a login link!") - + # Create the Streamlit app def auth_page(redirect_url: str):