From c2e2e927fb68a3bafc1606d3027caccadb882918 Mon Sep 17 00:00:00 2001 From: chabala98 Date: Fri, 1 Dec 2023 13:16:35 +0100 Subject: [PATCH 001/125] fix system prompts for replicate --- docs/my-website/docs/providers/replicate.md | 4 +- litellm/llms/replicate.py | 63 ++++++++++++--------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/docs/my-website/docs/providers/replicate.md b/docs/my-website/docs/providers/replicate.md index d8ab035e1..3384ba35c 100644 --- a/docs/my-website/docs/providers/replicate.md +++ b/docs/my-website/docs/providers/replicate.md @@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM Model Name | Function Call | Required OS Variables | -----------------------------|----------------------------------------------------------------|--------------------------------------| - replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` | - a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` | + replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` | + a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` | replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` | daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` | custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` | diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py index d639a8d1e..874b31bd6 100644 --- a/litellm/llms/replicate.py +++ b/litellm/llms/replicate.py @@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos else: # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed" print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}") + # Function to extract version ID from model string def model_to_version_id(model): @@ -194,41 +195,47 @@ def completion( ): # Start a prediction and get the prediction URL version_id = model_to_version_id(model) - ## Load Config config = litellm.ReplicateConfig.get_config() for k, v in config.items(): if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v - - if "meta/llama-2-13b-chat" in model: - system_prompt = "" - prompt = "" - for message in messages: - if message["role"] == "system": - system_prompt = message["content"] - else: - prompt += message["content"] - input_data = { - "system_prompt": system_prompt, - "prompt": prompt, - **optional_params - } + + system_prompt = None + if optional_params is not None and "supports_system_prompt" in 
optional_params: + supports_sys_prompt = optional_params.pop("supports_system_prompt") else: - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", {}), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - bos_token=model_prompt_details.get("bos_token", ""), - eos_token=model_prompt_details.get("eos_token", ""), - messages=messages, - ) - else: - prompt = prompt_factory(model=model, messages=messages) + supports_sys_prompt = False + + if supports_sys_prompt: + for i in range(len(messages)): + if messages[i]["role"] == "system": + first_sys_message = messages.pop(i) + system_prompt = first_sys_message["content"] + break + + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", {}), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + bos_token=model_prompt_details.get("bos_token", ""), + eos_token=model_prompt_details.get("eos_token", ""), + messages=messages, + ) + else: + prompt = prompt_factory(model=model, messages=messages) + # If system prompt is supported, and a system prompt is provided, use it + if system_prompt is not None: + input_data = { + "prompt": prompt, + "system_prompt": system_prompt + } + # Otherwise, use the prompt as is + else: input_data = { "prompt": prompt, **optional_params From f2625bca243e1bff96cb8a613173010084ece914 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 08:07:04 -0800 Subject: [PATCH 002/125] (docs) quick start proxy --- docs/my-website/docs/proxy/quick_start.md | 125 +++++++++++----------- 1 file changed, 63 insertions(+), 62 deletions(-) diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index f2ae9ea22..f1749bc50 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -87,69 +87,8 @@ print(response) -## Quick Start - LiteLLM Proxy + Config.yaml -The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). 
See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs) -### Create a Config for LiteLLM Proxy -Example config - -```yaml -model_list: - - model_name: gpt-3.5-turbo - litellm_params: - model: azure/ - api_base: - api_key: - - model_name: gpt-3.5-turbo - litellm_params: - model: azure/gpt-turbo-small-ca - api_base: https://my-endpoint-canada-berri992.openai.azure.com/ - api_key: -``` - -### Run proxy with config - -```shell -litellm --config your_config.yaml -``` - -## Quick Start Docker Image: Github Container Registry - -### Pull the litellm ghcr docker image -See the latest available ghcr docker image here: -https://github.com/berriai/litellm/pkgs/container/litellm - -```shell -docker pull ghcr.io/berriai/litellm:main-v1.10.1 -``` - -### Run the Docker Image -```shell -docker run ghcr.io/berriai/litellm:main-v1.10.0 -``` - -#### Run the Docker Image with LiteLLM CLI args - -See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli): - -Here's how you can run the docker image and pass your config to `litellm` -```shell -docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml -``` - -Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8` -```shell -docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8 -``` - -## Server Endpoints -- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs -- POST `/completions` - completions endpoint -- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints -- GET `/models` - available models on server -- POST `/key/generate` - generate a key to access the proxy - -## Supported LLMs +### Supported LLMs All LiteLLM supported LLMs are supported on the Proxy. Seel all [supported llms](https://docs.litellm.ai/docs/providers) @@ -301,6 +240,68 @@ $ litellm --model command-nightly +## Quick Start - LiteLLM Proxy + Config.yaml +The config allows you to create a model list and set `api_base`, `max_tokens` (all litellm params). 
See more details about the config [here](https://docs.litellm.ai/docs/proxy/configs) + +### Create a Config for LiteLLM Proxy +Example config + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/ + api_base: + api_key: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/gpt-turbo-small-ca + api_base: https://my-endpoint-canada-berri992.openai.azure.com/ + api_key: +``` + +### Run proxy with config + +```shell +litellm --config your_config.yaml +``` + +## Quick Start Docker Image: Github Container Registry + +### Pull the litellm ghcr docker image +See the latest available ghcr docker image here: +https://github.com/berriai/litellm/pkgs/container/litellm + +```shell +docker pull ghcr.io/berriai/litellm:main-v1.10.1 +``` + +### Run the Docker Image +```shell +docker run ghcr.io/berriai/litellm:main-v1.10.0 +``` + +#### Run the Docker Image with LiteLLM CLI args + +See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli): + +Here's how you can run the docker image and pass your config to `litellm` +```shell +docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml +``` + +Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8` +```shell +docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8 +``` + +## Server Endpoints +- POST `/chat/completions` - chat completions endpoint to call 100+ LLMs +- POST `/completions` - completions endpoint +- POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints +- GET `/models` - available models on server +- POST `/key/generate` - generate a key to access the proxy + ## Using with OpenAI compatible projects Set `base_url` to the LiteLLM Proxy server From b1bd799be86cafdb008ae4360fe886e242708879 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 08:26:06 -0800 Subject: [PATCH 003/125] (feat) proxy: use custom_logger --- litellm/proxy/custom_logger.py | 19 +++++++++++++++++++ litellm/proxy/proxy_config.yaml | 3 +++ 2 files changed, 22 insertions(+) create mode 100644 litellm/proxy/custom_logger.py diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_logger.py new file mode 100644 index 000000000..8a1a824ac --- /dev/null +++ b/litellm/proxy/custom_logger.py @@ -0,0 +1,19 @@ +from litellm.integrations.custom_logger import CustomLogger +class MyCustomHandler(CustomLogger): + def log_pre_api_call(self, model, messages, kwargs): + print(f"Pre-API Call") + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + # log: key, user, model, prompt, response, tokens, cost + print(f"Post-API Call") + + def log_stream_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Stream") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Success") + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Failure") + +customHandler = MyCustomHandler() diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index e7848f0dc..2d45dd91b 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -3,6 +3,9 @@ model_list: litellm_params: model: gpt-3.5-turbo +litellm_settings: + callbacks: [custom_logger.customHandler] # sets litellm.callbacks = [module.module_variable] + general_settings: # otel: True # OpenTelemetry Logger # master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer 
sk-1234) From ed8c666922796792af72ff7ce1fe1fc3e1d30f06 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 08:44:20 -0800 Subject: [PATCH 004/125] (feat) proxy: custom_logger for I/O logging --- litellm/proxy/proxy_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c9ba4c215..a2605dd55 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -117,7 +117,9 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) -def log_input_output(request, response): +def log_input_output(request, response, custom_logger=None): + if custom_logger is not None: + custom_logger(request, response) global otel_logging if otel_logging != True: return From 65e00b438ea348560356fee47f6f0b3edcaa1667 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:09:54 -0800 Subject: [PATCH 005/125] (feat) proxy-read litellm custom callback class --- litellm/proxy/custom_logger.py | 15 ++++++++++++++- litellm/proxy/proxy_server.py | 19 +++++++++++++++++++ pyproject.toml | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_logger.py index 8a1a824ac..d30722bd9 100644 --- a/litellm/proxy/custom_logger.py +++ b/litellm/proxy/custom_logger.py @@ -1,4 +1,5 @@ from litellm.integrations.custom_logger import CustomLogger +import litellm class MyCustomHandler(CustomLogger): def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") @@ -6,6 +7,16 @@ class MyCustomHandler(CustomLogger): def log_post_api_call(self, kwargs, response_obj, start_time, end_time): # log: key, user, model, prompt, response, tokens, cost print(f"Post-API Call") + print("\n kwargs\n") + print(kwargs) + model = kwargs["model"] + messages = kwargs["messages"] + cost = litellm.completion_cost(completion_response=response_obj) + + # tokens used in response + usage = response_obj.usage + print(usage) + def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") @@ -16,4 +27,6 @@ class MyCustomHandler(CustomLogger): def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") -customHandler = MyCustomHandler() +proxy_handler_instance = MyCustomHandler() + +# need to set litellm.callbacks = [customHandler] # on the proxy diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a2605dd55..a0e9250ac 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6,6 +6,7 @@ from typing import Optional, List import secrets, subprocess import hashlib, uuid import warnings +import importlib messages: list = [] sys.path.insert( 0, os.path.abspath("../..") @@ -556,6 +557,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): port=cache_port, password=cache_password ) + elif key == "callbacks": + print(f"{blue_color_code}\nSetting custom callbacks on Proxy") + print() + passed_module, instance_name = value.split(".") + + # Dynamically import the module + module = importlib.import_module(passed_module) + # Get the instance from the module + instance = getattr(module, instance_name) + + methods = [method for method in dir(instance) if callable(getattr(instance, method))] + # Print the methods + print("Methods in the instance:") + for method in methods: + print(method) + + litellm.callbacks = [instance] + else: setattr(litellm, key, value) diff --git a/pyproject.toml b/pyproject.toml index 
37f87fa45..2befc1383 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ proxy = [ "backoff", "rq", "orjson", + "importlib", ] extra_proxy = [ From 6599263a855c1172759f8824b10657b3b72434bd Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:10:12 -0800 Subject: [PATCH 006/125] (feat) proxy: custom callbacks --- litellm/proxy/proxy_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 2d45dd91b..8cf8ada65 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -4,7 +4,7 @@ model_list: model: gpt-3.5-turbo litellm_settings: - callbacks: [custom_logger.customHandler] # sets litellm.callbacks = [module.module_variable] + callbacks: custom_logger.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] general_settings: # otel: True # OpenTelemetry Logger From 0d44f5e441364103de452cd85d7355d70f6beba6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:23:55 -0800 Subject: [PATCH 007/125] (feat) proxy:custom_logger --- litellm/proxy/custom_logger.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_logger.py index d30722bd9..544283723 100644 --- a/litellm/proxy/custom_logger.py +++ b/litellm/proxy/custom_logger.py @@ -5,24 +5,35 @@ class MyCustomHandler(CustomLogger): print(f"Pre-API Call") def log_post_api_call(self, kwargs, response_obj, start_time, end_time): - # log: key, user, model, prompt, response, tokens, cost print(f"Post-API Call") - print("\n kwargs\n") - print(kwargs) - model = kwargs["model"] - messages = kwargs["messages"] - cost = litellm.completion_cost(completion_response=response_obj) - # tokens used in response - usage = response_obj.usage - print(usage) - - def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") def log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Success") + # log: key, user, model, prompt, response, tokens, cost + print("\n kwargs\n") + print(kwargs) + model = kwargs["model"] + messages = kwargs["messages"] + cost = litellm.completion_cost(completion_response=response_obj) + response = response_obj + # tokens used in response + usage = response_obj["usage"] + + print( + f""" + Model: {model}, + Messages: {messages}, + Usage: {usage}, + Cost: {cost}, + Response: {response} + """ + ) + + print(usage) + def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") From 31d9762b506ec5ce2783b38e02a9621846242ab2 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 09:36:37 -0800 Subject: [PATCH 008/125] fix(model_management.md): add docs on model management on proxy --- docs/my-website/docs/proxy/configs.md | 142 +++++++----------- .../my-website/docs/proxy/model_management.md | 74 +++++++++ docs/my-website/docs/proxy/virtual_keys.md | 3 +- docs/my-website/sidebars.js | 1 + 4 files changed, 131 insertions(+), 89 deletions(-) create mode 100644 docs/my-website/docs/proxy/model_management.md diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index de95ce94a..71ce7de02 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -8,71 +8,46 @@ Set model list, `api_base`, `api_key`, `temperature` & proxy server settings (`m | `general_settings` | Server settings, example setting 
`master_key: sk-my_special_key` | | `environment_variables` | Environment Variables example, `REDIS_HOST`, `REDIS_PORT` | -#### Example Config +## Quick Start + +Set a model alias for your deployments. + +In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment. + +In the config below requests with: +- `model=vllm-models` will route to `openai/facebook/opt-125m`. +- `model=gpt-3.5-turbo` will load balance between `azure/gpt-turbo-small-eu` and `azure/gpt-turbo-small-ca` + ```yaml model_list: - - model_name: gpt-3.5-turbo - litellm_params: + - model_name: gpt-3.5-turbo # user-facing model alias + litellm_params: # all params accepted by litellm.completion() - https://docs.litellm.ai/docs/completion/input model: azure/gpt-turbo-small-eu api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ - api_key: + api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU") rpm: 6 # Rate limit for this deployment: in requests per minute (rpm) - model_name: gpt-3.5-turbo litellm_params: model: azure/gpt-turbo-small-ca api_base: https://my-endpoint-canada-berri992.openai.azure.com/ - api_key: + api_key: "os.environ/AZURE_API_KEY_CA" rpm: 6 - - model_name: gpt-3.5-turbo + - model_name: vllm-models litellm_params: - model: azure/gpt-turbo-large - api_base: https://openai-france-1234.openai.azure.com/ - api_key: + model: openai/facebook/opt-125m # the `openai/` prefix tells litellm it's openai compatible + api_base: http://0.0.0.0:8000 rpm: 1440 + model_info: + version: 2 -litellm_settings: +litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py drop_params: True set_verbose: True general_settings: master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) - - -environment_variables: - OPENAI_API_KEY: sk-123 - REPLICATE_API_KEY: sk-cohere-is-okay - REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com - REDIS_PORT: "16337" - REDIS_PASSWORD: ``` -### Config for Multiple Models - GPT-4, Claude-2 - -Here's how you can use multiple llms with one proxy `config.yaml`. - -#### Step 1: Setup Config -```yaml -model_list: - - model_name: zephyr-alpha # the 1st model is the default on the proxy - litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body - model: huggingface/HuggingFaceH4/zephyr-7b-alpha - api_base: http://0.0.0.0:8001 - - model_name: gpt-4 - litellm_params: - model: gpt-4 - api_key: sk-1233 - - model_name: claude-2 - litellm_params: - model: claude-2 - api_key: sk-claude -``` - -:::info - -The proxy uses the first model in the config as the default model - in this config the default model is `zephyr-alpha` -::: - - #### Step 2: Start Proxy with config ```shell @@ -96,32 +71,11 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ ' ``` -### Config for Embedding Models - xorbitsai/inference - -Here's how you can use multiple llms with one proxy `config.yaml`. 
-Here is how [LiteLLM calls OpenAI Compatible Embedding models](https://docs.litellm.ai/docs/embedding/supported_embedding#openai-compatible-embedding-models) - -#### Config -```yaml -model_list: - - model_name: custom_embedding_model - litellm_params: - model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible - api_base: http://0.0.0.0:8000/ - - model_name: custom_embedding_model - litellm_params: - model: openai/custom_embedding # the `openai/` prefix tells litellm it's openai compatible - api_base: http://0.0.0.0:8001/ -``` - -Run the proxy using this config -```shell -$ litellm --config /path/to/config.yaml -``` - -### Save Model-specific params (API Base, API Keys, Temperature, Headers etc.) +## Save Model-specific params (API Base, API Keys, Temperature, Headers etc.) You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc. +[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1) + **Step 1**: Create a `config.yaml` file ```yaml model_list: @@ -152,9 +106,11 @@ model_list: $ litellm --config /path/to/config.yaml ``` -### Load API Keys from Vault +## Load API Keys -If you have secrets saved in Azure Vault, etc. and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment. +### Load API Keys from Environment + +If you have secrets saved in your environment, and don't want to expose them in the config.yaml, here's how to load model-specific keys from the environment. ```python os.environ["AZURE_NORTH_AMERICA_API_KEY"] = "your-azure-api-key" @@ -174,30 +130,42 @@ model_list: s/o to [@David Manouchehri](https://www.linkedin.com/in/davidmanouchehri/) for helping with this. -### Config for setting Model Aliases +### Load API Keys from Azure Vault -Set a model alias for your deployments. +1. Install Proxy dependencies +```bash +$ pip install litellm[proxy] litellm[extra_proxy] +``` -In the `config.yaml` the model_name parameter is the user-facing name to use for your deployment. - -In the config below requests with `model=gpt-4` will route to `ollama/llama2` +2. Save Azure details in your environment +```bash +export["AZURE_CLIENT_ID"]="your-azure-app-client-id" +export["AZURE_CLIENT_SECRET"]="your-azure-app-client-secret" +export["AZURE_TENANT_ID"]="your-azure-tenant-id" +export["AZURE_KEY_VAULT_URI"]="your-azure-key-vault-uri" +``` +3. Add to proxy config.yaml ```yaml -model_list: - - model_name: text-davinci-003 - litellm_params: - model: ollama/zephyr - - model_name: gpt-4 - litellm_params: - model: ollama/llama2 - - model_name: gpt-3.5-turbo - litellm_params: - model: ollama/llama2 +model_list: + - model_name: "my-azure-models" # model alias + litellm_params: + model: "azure/" + api_key: "os.environ/AZURE-API-KEY" # reads from key vault - get_secret("AZURE_API_KEY") + api_base: "os.environ/AZURE-API-BASE" # reads from key vault - get_secret("AZURE_API_BASE") + +general_settings: + use_azure_key_vault: True +``` + +You can now test this by starting your proxy: +```bash +litellm --config /path/to/config.yaml ``` ### Set Custom Prompt Templates -LiteLLM by default checks if a model has a [prompt template and applies it](./completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in it's tokenizer_config.json). 
However, you can also set a custom prompt template on your proxy in the `config.yaml`: +LiteLLM by default checks if a model has a [prompt template and applies it](../completion/prompt_formatting.md) (e.g. if a huggingface model has a saved chat template in it's tokenizer_config.json). However, you can also set a custom prompt template on your proxy in the `config.yaml`: **Step 1**: Save your prompt template in a `config.yaml` ```yaml diff --git a/docs/my-website/docs/proxy/model_management.md b/docs/my-website/docs/proxy/model_management.md new file mode 100644 index 000000000..0cd4ab829 --- /dev/null +++ b/docs/my-website/docs/proxy/model_management.md @@ -0,0 +1,74 @@ +# Model Management +Add new models + Get model info without restarting proxy. + +## Get Model Information + +Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes. + + + + +```bash +curl -X GET "http://0.0.0.0:8000/model/info" \ + -H "accept: application/json" \ +``` + + + +## Add a New Model + +Add a new model to the list in the `config.yaml` by providing the model parameters. This allows you to update the model list without restarting the proxy. + + + + +```bash +curl -X POST "http://0.0.0.0:8000/model/new" \ + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }' +``` + + + + +### Model Parameters Structure + +When adding a new model, your JSON payload should conform to the following structure: + +- `model_name`: The name of the new model (required). +- `litellm_params`: A dictionary containing parameters specific to the Litellm setup (required). +- `model_info`: An optional dictionary to provide additional information about the model. + +Here's an example of how to structure your `ModelParams`: + +```json +{ + "model_name": "my_awesome_model", + "litellm_params": { + "some_parameter": "some_value", + "another_parameter": "another_value" + }, + "model_info": { + "author": "Your Name", + "version": "1.0", + "description": "A brief description of the model." + } +} +``` +--- + +Keep in mind that as both endpoints are in [BETA], you may need to visit the associated GitHub issues linked in the API descriptions to check for updates or provide feedback: + +- Get Model Information: [Issue #933](https://github.com/BerriAI/litellm/issues/933) +- Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964) + +Feedback on the beta endpoints is valuable and helps improve the API for all users. \ No newline at end of file diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 26c22e47a..ac1cf54de 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -1,5 +1,4 @@ - -# Cost Tracking & Virtual Keys +# Key Management Track Spend and create virtual keys for the proxy Grant other's temporary access to your proxy, with keys that expire after a set duration. 
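The key-management docs above can also be exercised from Python rather than curl. Below is a minimal sketch of calling the proxy's `/key/generate` endpoint, assuming the proxy is running locally on port 8000 with `master_key: sk-1234` as in the earlier config examples; the request fields (`duration`, `models`, `aliases`) follow the payloads used in this series' tests, and the field names here are illustrative rather than exhaustive.

```python
import requests

PROXY_BASE = "http://0.0.0.0:8000"   # assumed local proxy address
MASTER_KEY = "sk-1234"               # assumed master key from the example config

def generate_virtual_key() -> str:
    """Request a temporary virtual key from the LiteLLM proxy."""
    response = requests.post(
        f"{PROXY_BASE}/key/generate",
        headers={"Authorization": f"Bearer {MASTER_KEY}"},
        json={
            "duration": "20m",                           # key expires after 20 minutes
            "models": ["gpt-3.5-turbo", "gpt-4"],        # models this key may call
            "aliases": {"mistral-7b": "gpt-3.5-turbo"},  # optional alias mapping
        },
    )
    response.raise_for_status()
    data = response.json()  # expected fields: key, expires, user_id
    return data["key"]

if __name__ == "__main__":
    print(generate_virtual_key())
```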
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index dd9aa514f..11f81fa4d 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -99,6 +99,7 @@ const sidebars = { "proxy/configs", "proxy/load_balancing", "proxy/virtual_keys", + "proxy/model_management", "proxy/caching", "proxy/logging", "proxy/cli", From 93f5c266da15fdded0061bf148eb074543416f90 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:45:03 -0800 Subject: [PATCH 009/125] (test) test completion: if 'user' passed to API --- litellm/tests/test_completion.py | 47 +++++++++++++++++++++++++++++--- litellm/utils.py | 3 +- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 3fbbdea0e..54bc53326 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -442,9 +442,46 @@ def test_completion_text_openai(): pytest.fail(f"Error occurred: {e}") # test_completion_text_openai() +def custom_callback( + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, end_time # start/end time +): + # Your custom code here + try: + print("LITELLM: in custom callback function") + print("\nkwargs\n", kwargs) + model = kwargs["model"] + messages = kwargs["messages"] + user = kwargs.get("user") + + ################################################# + + print( + f""" + Model: {model}, + Messages: {messages}, + User: {user}, + Seed: {kwargs["seed"]}, + temperature: {kwargs["temperature"]}, + """ + ) + + assert kwargs["user"] == "ishaans app" + assert kwargs["model"] == "gpt-3.5-turbo-1106" + assert kwargs["seed"] == 12 + assert kwargs["temperature"] == 0.5 + except Exception as e: + pytest.fail(f"Error occurred: {e}") + def test_completion_openai_with_optional_params(): + # [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST + # assert that `user` gets passed to the completion call + # Note: This tests that we actually send the optional params to the completion call + # We use custom callbacks to test this try: litellm.set_verbose = True + litellm.success_callback = [custom_callback] response = completion( model="gpt-3.5-turbo-1106", messages=[ @@ -458,15 +495,17 @@ def test_completion_openai_with_optional_params(): seed=12, response_format={ "type": "json_object" }, logit_bias=None, + user = "ishaans app" ) # Add any assertions here to check the response + print(response) - except litellm.Timeout as e: - pass + litellm.success_callback = [] # unset callbacks + except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_openai_with_optional_params() +test_completion_openai_with_optional_params() def test_completion_openai_litellm_key(): try: @@ -1337,7 +1376,7 @@ def test_azure_cloudflare_api(): traceback.print_exc() pass -test_azure_cloudflare_api() +# test_azure_cloudflare_api() def test_completion_anyscale_2(): try: diff --git a/litellm/utils.py b/litellm/utils.py index 280a6342f..b756fc358 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -544,7 +544,8 @@ class Logging: "optional_params": self.optional_params, "litellm_params": self.litellm_params, "start_time": self.start_time, - "stream": self.stream + "stream": self.stream, + **self.optional_params } def pre_call(self, input, api_key, model=None, additional_args={}): From f3b939e603a85fdf9e3a67d4d134e26338dd8725 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:50:34 -0800 Subject: [PATCH 010/125] (fix) access `user` in custom logger --- 
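The `user` value this patch reads inside the custom logger arrives on the original request to the proxy. A minimal sketch of that client side, assuming the proxy is listening on `http://0.0.0.0:8000` and accepts the placeholder key `sk-1234`; the body mirrors the curl examples used later in this series, and the `user` field is what surfaces as `kwargs["user"]` in the success callback.

```python
import requests

PROXY_BASE = "http://0.0.0.0:8000"  # assumed local proxy address
API_KEY = "sk-1234"                 # assumed key accepted by the proxy

# Call the proxy's OpenAI-compatible chat endpoint, tagging the request with a user id.
response = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "good morning good sir"}],
        "user": "ishaan-app",   # surfaced to callbacks as kwargs["user"]
        "temperature": 0.2,
    },
)
print(response.json())
```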
litellm/proxy/custom_logger.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_logger.py index 544283723..06d5fc127 100644 --- a/litellm/proxy/custom_logger.py +++ b/litellm/proxy/custom_logger.py @@ -15,8 +15,13 @@ class MyCustomHandler(CustomLogger): # log: key, user, model, prompt, response, tokens, cost print("\n kwargs\n") print(kwargs) + ### Access kwargs passed to litellm.completion() model = kwargs["model"] messages = kwargs["messages"] + user = kwargs.get("user") + ################################################# + + ### Calculate cost ####################### cost = litellm.completion_cost(completion_response=response_obj) response = response_obj # tokens used in response @@ -26,6 +31,7 @@ class MyCustomHandler(CustomLogger): f""" Model: {model}, Messages: {messages}, + User: {user}, Usage: {usage}, Cost: {cost}, Response: {response} From 88cec3b9ab07a7a58ea5777858cba5a5a39ab5f0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:17:56 -0800 Subject: [PATCH 011/125] (fix) proxy: don't overwrite `user` --- litellm/proxy/proxy_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a0e9250ac..d1840073a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -908,8 +908,9 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap or model # for azure deployments or data["model"] # default passed in http request ) - - data["user"] = user_api_key_dict.get("user_id", None) + # users can pass in 'user' param to /chat/completions. Don't override it + if data["user"] is None: + data["user"] = user_api_key_dict.get("user_id", None) if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] From ebd9404cfde25ff9e0a5da0f05c884cfdf0c86fc Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:19:35 -0800 Subject: [PATCH 012/125] (test) proxy: don't overwrite user --- litellm/tests/test_proxy_server.py | 49 ++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 005de2762..53d9efaa9 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -45,7 +45,7 @@ def test_chat_completion(): pytest.fail("LiteLLM Proxy test failed. Exception", e) # Run the test -test_chat_completion() +# test_chat_completion() def test_chat_completion_azure(): @@ -119,4 +119,49 @@ def test_add_new_model(): except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception {str(e)}") -test_add_new_model() \ No newline at end of file +# test_add_new_model() + +from litellm.integrations.custom_logger import CustomLogger +class MyCustomHandler(CustomLogger): + def log_pre_api_call(self, model, messages, kwargs): + print(f"Pre-API Call") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Success") + assert kwargs["user"] == "proxy-user" + assert kwargs["model"] == "gpt-3.5-turbo" + assert kwargs["max_tokens"] == 10 + +customHandler = MyCustomHandler() + + +def test_chat_completion_optional_params(): + # [PROXY: PROD TEST] - DO NOT DELETE + # This tests if all the /chat/completion params are passed to litellm + + try: + # Your test data + litellm.set_verbose=True + test_data = { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "hi" + }, + ], + "max_tokens": 10, + "user": "proxy-user" + } + + litellm.callbacks = [customHandler] + print("testing proxy server: optional params") + response = client.post("/v1/chat/completions", json=test_data) + assert response.status_code == 200 + result = response.json() + print(f"Received response: {result}") + except Exception as e: + pytest.fail("LiteLLM Proxy test failed. Exception", e) + +# Run the test +test_chat_completion_optional_params() \ No newline at end of file From d7d8c5f6e6d5e1150d8d62dcd4f241704b105636 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:23:55 -0800 Subject: [PATCH 013/125] (fix) proxy --- litellm/proxy/proxy_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d1840073a..691f87939 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -908,8 +908,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap or model # for azure deployments or data["model"] # default passed in http request ) + # users can pass in 'user' param to /chat/completions. 
Don't override it - if data["user"] is None: + if data.get("user", None) is None: + # if users are using user_api_key_auth, set `user` in `data` data["user"] = user_api_key_dict.get("user_id", None) if "metadata" in data: From 63e55f1865ec0a0d02348c2d2b26811d1b6291ca Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 10:43:42 -0800 Subject: [PATCH 014/125] fix(proxy_server.py): fix /key/generate post endpoint --- litellm/proxy/proxy_server.py | 46 +++++++++-------- litellm/proxy/utils.py | 14 +++++- litellm/tests/test_config.yaml | 7 +++ litellm/tests/test_proxy_server.py | 2 +- litellm/tests/test_proxy_server_keys.py | 66 +++++++++++++++++++++++++ litellm/utils.py | 7 +-- 6 files changed, 115 insertions(+), 27 deletions(-) create mode 100644 litellm/tests/test_config.yaml create mode 100644 litellm/tests/test_proxy_server_keys.py diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 691f87939..9ebafe2dd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -213,11 +213,11 @@ class GenerateKeyRequest(BaseModel): aliases: dict = {} config: dict = {} spend: int = 0 - user_id: Optional[str] + user_id: Optional[str] = None class GenerateKeyResponse(BaseModel): key: str - expires: str + expires: datetime user_id: str class _DeleteKeyObject(BaseModel): @@ -277,6 +277,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap "api_key": None } try: + if api_key is None: + raise Exception("No api key passed in.") route = request.url.path # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead @@ -491,8 +493,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ## PRINT YAML FOR CONFIRMING IT WORKS printed_yaml = copy.deepcopy(config) printed_yaml.pop("environment_variables", None) - for model in printed_yaml["model_list"]: - model["litellm_params"].pop("api_key", None) + if "model_list" in printed_yaml: + for model in printed_yaml["model_list"]: + model["litellm_params"].pop("api_key", None) print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}") @@ -507,22 +510,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if general_settings is None: general_settings = {} if general_settings: - ### MASTER KEY ### - master_key = general_settings.get("master_key", None) - if master_key and master_key.startswith("os.environ/"): - master_key_env_name = master_key.replace("os.environ/", "") - master_key = os.getenv(master_key_env_name) + ### LOAD FROM AZURE KEY VAULT ### + use_azure_key_vault = general_settings.get("use_azure_key_vault", False) + load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) + ### CONNECT TO DATABASE ### database_url = general_settings.get("database_url", None) + if database_url and database_url.startswith("os.environ/"): + database_url = litellm.get_secret(database_url) prisma_setup(database_url=database_url) ## COST TRACKING ## cost_tracking() ### START REDIS QUEUE ### use_queue = general_settings.get("use_queue", False) celery_setup(use_queue=use_queue) - ### LOAD FROM AZURE KEY VAULT ### - use_azure_key_vault = general_settings.get("use_azure_key_vault", False) - load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) + ### MASTER KEY ### + master_key = general_settings.get("master_key", None) + if master_key and master_key.startswith("os.environ/"): + master_key = 
litellm.get_secret(master_key) #### OpenTelemetry Logging (OTEL) ######## otel_logging = general_settings.get("otel", False) @@ -540,9 +545,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}\nSetting Cache on Proxy") from litellm.caching import Cache cache_type = value["type"] - cache_host = os.environ.get("REDIS_HOST") - cache_port = os.environ.get("REDIS_PORT") - cache_password = os.environ.get("REDIS_PASSWORD") + cache_host = litellm.get_secret("REDIS_HOST") + cache_port = litellm.get_secret("REDIS_PORT") + cache_password = litellm.get_secret("REDIS_PASSWORD") # Assuming cache_type, cache_host, cache_port, and cache_password are strings print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") @@ -794,12 +799,14 @@ def litellm_completion(*args, **kwargs): return StreamingResponse(data_generator(response), media_type='text/event-stream') return response -@app.on_event("startup") +@router.on_event("startup") async def startup_event(): global prisma_client, master_key import json worker_config = json.loads(os.getenv("WORKER_CONFIG")) + print(f"worker_config: {worker_config}") initialize(**worker_config) + print(f"prisma client - {prisma_client}") if prisma_client: await prisma_client.connect() @@ -807,7 +814,7 @@ async def startup_event(): # add master key to db await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key) -@app.on_event("shutdown") +@router.on_event("shutdown") async def shutdown_event(): global prisma_client if prisma_client: @@ -1022,8 +1029,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest): - key: The generated api key - expires: Datetime object for when key expires. """ - data = await request.json() - + # data = await request.json() duration_str = data.duration # Default to 1 hour if duration is not provided models = data.models # Default to an empty list (meaning allow token to call all models) aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list) @@ -1042,8 +1048,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest): @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) async def delete_key_fn(request: Request, data: DeleteKeyRequest): try: - data = await request.json() - keys = data.keys deleted_keys = await delete_verification_token(tokens=keys) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 532321ad4..7acdd9b4b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -5,8 +5,18 @@ class PrismaClient: def __init__(self, database_url: str): print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") os.environ["DATABASE_URL"] = database_url - subprocess.run(['prisma', 'generate']) - subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss + # Save the current working directory + original_dir = os.getcwd() + # set the working directory to where this script is + abspath = os.path.abspath(__file__) + dname = os.path.dirname(abspath) + os.chdir(dname) + + try: + subprocess.run(['prisma', 'generate']) + subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. 
we need to have the --accept-data-loss + finally: + os.chdir(original_dir) # Now you can import the Prisma Client from prisma import Client self.db = Client() #Client to connect to Prisma db diff --git a/litellm/tests/test_config.yaml b/litellm/tests/test_config.yaml new file mode 100644 index 000000000..0e678d2d3 --- /dev/null +++ b/litellm/tests/test_config.yaml @@ -0,0 +1,7 @@ +litellm_settings: + drop_params: True + set_verbose: True + +general_settings: + master_key: "os.environ/PROXY_MASTER_KEY" + database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 53d9efaa9..805e2a860 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -164,4 +164,4 @@ def test_chat_completion_optional_params(): pytest.fail("LiteLLM Proxy test failed. Exception", e) # Run the test -test_chat_completion_optional_params() \ No newline at end of file +test_chat_completion_optional_params() diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py new file mode 100644 index 000000000..b0e7d33bd --- /dev/null +++ b/litellm/tests/test_proxy_server_keys.py @@ -0,0 +1,66 @@ +import sys, os +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os, io + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest, logging +import litellm +from litellm import embedding, completion, completion_cost, Timeout +from litellm import RateLimitError +# Configure logging +logging.basicConfig( + level=logging.DEBUG, # Set the desired logging level + format="%(asctime)s - %(levelname)s - %(message)s", +) + +# test /chat/completion request to the proxy +from fastapi.testclient import TestClient +from fastapi import FastAPI +from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +cwd = os.getcwd() +config_fp = f"{cwd}/test_config.yaml" +save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +app = FastAPI() +app.include_router(router) # Include your router in the test app +@app.on_event("startup") +async def wrapper_startup_event(): + await startup_event() + +# Here you create a fixture that will be used by your tests +# Make sure the fixture returns TestClient(app) +@pytest.fixture(autouse=True) +def client(): + with TestClient(app) as client: + yield client + +def test_add_new_key(client): + try: + # Your test data + test_data = { + "models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": "20m" + } + print("testing proxy server") + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/key/generate", json=test_data, headers=headers) + print(f"response: {response.text}") + assert response.status_code == 200 + result = response.json() + print(f"Received response: {result}") + except Exception as e: + pytest.fail("LiteLLM Proxy test failed. 
Exception", e) + +# # Run the test - only runs via pytest \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index b756fc358..dce259237 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2421,8 +2421,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ return model, custom_llm_provider, dynamic_api_key, api_base if api_key and api_key.startswith("os.environ/"): - api_key_env_name = api_key.replace("os.environ/", "") - dynamic_api_key = get_secret(api_key_env_name) + dynamic_api_key = get_secret(api_key) # check if llm provider part of model name if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list: custom_llm_provider = model.split("/", 1)[0] @@ -4722,7 +4721,9 @@ def litellm_telemetry(data): ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name): +def get_secret(secret_name: str): + if secret_name.startswith("os.environ/"): + secret_name = secret_name.replace("os.environ/", "") if litellm.secret_manager_client is not None: # TODO: check which secret manager is being used # currently only supports Infisical From 9b1e02cdf1d5dd548455727d661a877107a79762 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:48:22 -0800 Subject: [PATCH 015/125] (chore) rename: proxy-custom logger --- .../proxy/{custom_logger.py => custom_callbacks.py} | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) rename litellm/proxy/{custom_logger.py => custom_callbacks.py} (79%) diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_callbacks.py similarity index 79% rename from litellm/proxy/custom_logger.py rename to litellm/proxy/custom_callbacks.py index 06d5fc127..bf4a837cd 100644 --- a/litellm/proxy/custom_logger.py +++ b/litellm/proxy/custom_callbacks.py @@ -1,5 +1,8 @@ from litellm.integrations.custom_logger import CustomLogger import litellm + +# This file includes the custom callbacks for LiteLLM Proxy +# Once defined, these can be passed in proxy_config.yaml class MyCustomHandler(CustomLogger): def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") @@ -11,14 +14,17 @@ class MyCustomHandler(CustomLogger): print(f"On Stream") def log_success_event(self, kwargs, response_obj, start_time, end_time): - print(f"On Success") # log: key, user, model, prompt, response, tokens, cost + print("\nOn Success\n") print("\n kwargs\n") print(kwargs) ### Access kwargs passed to litellm.completion() model = kwargs["model"] messages = kwargs["messages"] - user = kwargs.get("user") + user = kwargs.get("user", None) + + litellm_params = kwargs.get("litellm_params", {}) + metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here ################################################# ### Calculate cost ####################### @@ -35,6 +41,7 @@ class MyCustomHandler(CustomLogger): Usage: {usage}, Cost: {cost}, Response: {response} + Proxy Metadata: {metadata} """ ) From ec579d48218548087f71bf81a313650ddede54d2 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:48:36 -0800 Subject: [PATCH 016/125] (chore) rename proxy: custom_callbacks --- litellm/proxy/proxy_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 8cf8ada65..071a82205 100644 --- a/litellm/proxy/proxy_config.yaml +++ 
b/litellm/proxy/proxy_config.yaml @@ -4,7 +4,7 @@ model_list: model: gpt-3.5-turbo litellm_settings: - callbacks: custom_logger.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] + callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] general_settings: # otel: True # OpenTelemetry Logger From e96a60893860623e1fa7e32285a6f4de77e50533 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 10:48:59 -0800 Subject: [PATCH 017/125] (feat) proxy: set custom headers in metadata --- litellm/proxy/proxy_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9ebafe2dd..1238fd8e2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -923,9 +923,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["headers"] = request.headers else: data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} - + data["metadata"]["headers"] = request.headers global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli if user_temperature: From bfe0172108ac6ff795db99964f87a8d8ebe3a3aa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 10:51:25 -0800 Subject: [PATCH 018/125] =?UTF-8?q?bump:=20version=201.10.2=20=E2=86=92=20?= =?UTF-8?q?1.10.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2befc1383..a9878ced7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.2" +version = "1.10.3" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.2" +version = "1.10.3" version_files = [ "pyproject.toml:^version" ] From c5b92837c2c2811e74a9afc555d04cbf14c5118b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:20:27 -0800 Subject: [PATCH 019/125] (docs) custom callbacks proxy --- docs/my-website/docs/proxy/logging.md | 124 +++++++++++++++++++++++++- litellm/proxy/custom_callbacks.py | 15 ++-- 2 files changed, 128 insertions(+), 11 deletions(-) diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 21d53a7dc..851b8a4f1 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -1,5 +1,125 @@ -# Logging - OpenTelemetry, Langfuse, ElasticSearch -Log Proxy Input, Output, Exceptions to Langfuse, OpenTelemetry +# Logging - Custom Callbacks, OpenTelemetry, Langfuse, ElasticSearch +Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry + +## Custom Callbacks +Use this when you want to run custom callbacks in `python` + +### Step 1 - Create your custom `litellm` callback class +We use `litellm.integrations.custom_logger` for this, **more details about litellm custom callbacks [here](https://docs.litellm.ai/docs/observability/custom_callback)** + +Define your custom callback class in a python file. 
+ +Here's an example custom logger for tracking `key, user, model, prompt, response, tokens, cost`. We create a file called `custom_callbacks.py` and initialize `proxy_handler_instance` + +```python +from litellm.integrations.custom_logger import CustomLogger +import litellm + +# This file includes the custom callbacks for LiteLLM Proxy +# Once defined, these can be passed in proxy_config.yaml +class MyCustomHandler(CustomLogger): + def log_pre_api_call(self, model, messages, kwargs): + print(f"Pre-API Call") + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + print(f"Post-API Call") + + def log_stream_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Stream") + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + # log: key, user, model, prompt, response, tokens, cost + print("\nOn Success") + ### Access kwargs passed to litellm.completion() + model = kwargs.get("model", None) + messages = kwargs.get("messages", None) + user = kwargs.get("user", None) + + #### Access litellm_params passed to litellm.completion(), example access `metadata` + litellm_params = kwargs.get("litellm_params", {}) + metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here + ################################################# + + ##### Calculate cost using litellm.completion_cost() ####################### + cost = litellm.completion_cost(completion_response=response_obj) + response = response_obj + # tokens used in response + usage = response_obj["usage"] + + print( + f""" + Model: {model}, + Messages: {messages}, + User: {user}, + Usage: {usage}, + Cost: {cost}, + Response: {response} + Proxy Metadata: {metadata} + """ + ) + return + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + print(f"On Failure") + +proxy_handler_instance = MyCustomHandler() + +# need to set litellm.callbacks = [customHandler] # on the proxy + + +``` + +### Step 2 - Pass your custom callback class in `config.yaml` +We pass the custom callback class defined in **Step1** to the config.yaml. + +Set `callbacks` to `python_filename.logger_instance_name` + +In the config below, the custom callback is defined in a file`custom_callbacks.py` and has an instance of `proxy_handler_instance = MyCustomHandler()`. + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo + +litellm_settings: + callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] + +``` + +### Step 3 - Start proxy + test request +```shell +litellm --config proxy_config.yaml +``` + +```shell +curl --location 'http://0.0.0.0:8000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --data ' { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "good morning good sir" + } + ], + "user": "ishaan-app", + "temperature": 0.2 + }' +``` + +#### Resulting Log on Proxy +```shell +On Success + Model: gpt-3.5-turbo, + Messages: [{'role': 'user', 'content': 'good morning good sir'}], + User: ishaan-app, + Usage: {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21}, + Cost: 3.65e-05, + Response: {'id': 'chatcmpl-8S8avKJ1aVBg941y5xzGMSKrYCMvN', 'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': 'Good morning! 
How can I assist you today?', 'role': 'assistant'}}], 'created': 1701716913, 'model': 'gpt-3.5-turbo-0613', 'object': 'chat.completion', 'system_fingerprint': None, 'usage': {'completion_tokens': 10, 'prompt_tokens': 11, 'total_tokens': 21}} + Proxy Metadata: {'user_api_key': None, 'headers': Headers({'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'authorization': 'Bearer sk-1234', 'content-length': '199', 'content-type': 'application/x-www-form-urlencoded'}), 'model_group': 'gpt-3.5-turbo', 'deployment': 'gpt-3.5-turbo-ModelID-gpt-3.5-turbo'} +``` + ## OpenTelemetry, ElasticSearch ### Step 1 Start OpenTelemetry Collecter Docker Container diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py index bf4a837cd..a991ff1d9 100644 --- a/litellm/proxy/custom_callbacks.py +++ b/litellm/proxy/custom_callbacks.py @@ -15,19 +15,18 @@ class MyCustomHandler(CustomLogger): def log_success_event(self, kwargs, response_obj, start_time, end_time): # log: key, user, model, prompt, response, tokens, cost - print("\nOn Success\n") - print("\n kwargs\n") - print(kwargs) + print("\nOn Success") ### Access kwargs passed to litellm.completion() - model = kwargs["model"] - messages = kwargs["messages"] + model = kwargs.get("model", None) + messages = kwargs.get("messages", None) user = kwargs.get("user", None) + #### Access litellm_params passed to litellm.completion(), example access `metadata` litellm_params = kwargs.get("litellm_params", {}) metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here ################################################# - ### Calculate cost ####################### + ##### Calculate cost using litellm.completion_cost() ####################### cost = litellm.completion_cost(completion_response=response_obj) response = response_obj # tokens used in response @@ -44,9 +43,7 @@ class MyCustomHandler(CustomLogger): Proxy Metadata: {metadata} """ ) - - print(usage) - + return def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") From 4ef0378e6ea7cbb80de4dd463ac3b41ff5faeee6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:29:26 -0800 Subject: [PATCH 020/125] (fix) proxy: custom callbacks --- litellm/proxy/proxy_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1238fd8e2..ae3ec5298 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -564,7 +564,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ) elif key == "callbacks": print(f"{blue_color_code}\nSetting custom callbacks on Proxy") - print() passed_module, instance_name = value.split(".") # Dynamically import the module @@ -574,11 +573,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): methods = [method for method in dir(instance) if callable(getattr(instance, method))] # Print the methods - print("Methods in the instance:") + print("Methods in the custom callbacks instance:") for method in methods: print(method) litellm.callbacks = [instance] + print() else: setattr(litellm, key, value) From 85ac2b179ad1c00f13e735621e33971413eb8038 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:30:09 -0800 Subject: [PATCH 021/125] (docs) custom logger: proxy --- docs/my-website/docs/proxy/logging.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git 
a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 851b8a4f1..33a72c3d9 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -1,4 +1,4 @@ -# Logging - Custom Callbacks, OpenTelemetry, Langfuse, ElasticSearch +# Logging - Custom Callbacks, OpenTelemetry, Langfuse Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry ## Custom Callbacks @@ -64,9 +64,7 @@ class MyCustomHandler(CustomLogger): proxy_handler_instance = MyCustomHandler() -# need to set litellm.callbacks = [customHandler] # on the proxy - - +# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy ``` ### Step 2 - Pass your custom callback class in `config.yaml` @@ -83,7 +81,7 @@ model_list: model: gpt-3.5-turbo litellm_settings: - callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] + callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] ``` From 15f54c3072240eca8b615868d2a5f1702963ab16 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:30:53 -0800 Subject: [PATCH 022/125] (docs) default config proxy --- litellm/proxy/proxy_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 071a82205..5f4f9fcad 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -4,7 +4,7 @@ model_list: model: gpt-3.5-turbo litellm_settings: - callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [module.module_variable] + # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] general_settings: # otel: True # OpenTelemetry Logger From d0d8ba46c570726869a7dd376834076e0f794863 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 11:35:36 -0800 Subject: [PATCH 023/125] test(test_proxy_server_keys.py): fix relative import --- litellm/tests/test_proxy_server_keys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index b0e7d33bd..806b5f43e 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -24,8 +24,8 @@ logging.basicConfig( from fastapi.testclient import TestClient from fastapi import FastAPI from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined -cwd = os.getcwd() -config_fp = f"{cwd}/test_config.yaml" +filepath = os.path.dirname(os.path.abspath(__file__)) +config_fp = f"{filepath}/test_config.yaml" save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) app = FastAPI() app.include_router(router) # Include your router in the test app From 533b5bcc44a18c5f4ba8754cf3f6ec5736a0f096 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:37:14 -0800 Subject: [PATCH 024/125] (docs) clean up proxy logging --- docs/my-website/docs/proxy/logging.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 
33a72c3d9..4fea7a47e 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -28,23 +28,20 @@ class MyCustomHandler(CustomLogger): print(f"On Stream") def log_success_event(self, kwargs, response_obj, start_time, end_time): - # log: key, user, model, prompt, response, tokens, cost + # Logging key details: key, user, model, prompt, response, tokens, cost print("\nOn Success") - ### Access kwargs passed to litellm.completion() + # Access kwargs passed to litellm.completion() model = kwargs.get("model", None) messages = kwargs.get("messages", None) user = kwargs.get("user", None) - #### Access litellm_params passed to litellm.completion(), example access `metadata` + # Access litellm_params passed to litellm.completion(), example access `metadata` litellm_params = kwargs.get("litellm_params", {}) - metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here - ################################################# + metadata = litellm_params.get("metadata", {}) # Headers passed to LiteLLM proxy - ##### Calculate cost using litellm.completion_cost() ####################### + # Calculate cost using litellm.completion_cost() cost = litellm.completion_cost(completion_response=response_obj) - response = response_obj - # tokens used in response - usage = response_obj["usage"] + usage = response_obj["usage"] # Tokens used in response print( f""" @@ -64,6 +61,7 @@ class MyCustomHandler(CustomLogger): proxy_handler_instance = MyCustomHandler() +# Set litellm.callbacks = [proxy_handler_instance] on the proxy # need to set litellm.callbacks = [proxy_handler_instance] # on the proxy ``` From 333e77d161b4afb584a1374f33483c3f7d38ee9f Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:44:12 -0800 Subject: [PATCH 025/125] (docs) custom logger --- docs/my-website/docs/proxy/logging.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 4fea7a47e..95f6a83a9 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -67,10 +67,14 @@ proxy_handler_instance = MyCustomHandler() ### Step 2 - Pass your custom callback class in `config.yaml` We pass the custom callback class defined in **Step1** to the config.yaml. - Set `callbacks` to `python_filename.logger_instance_name` -In the config below, the custom callback is defined in a file`custom_callbacks.py` and has an instance of `proxy_handler_instance = MyCustomHandler()`. +In the config below, we pass +- python_filename: `custom_callbacks.py` +- logger_instance_name: `proxy_handler_instance`. 
This is defined in Step 1 + +`callbacks: custom_callbacks.proxy_handler_instance` + ```yaml model_list: From 6f8765125b28289e16bbeb890899e527f16e3480 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 11:57:41 -0800 Subject: [PATCH 026/125] (docs) embedding: add api_base for HF --- docs/my-website/docs/embedding/supported_embedding.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/my-website/docs/embedding/supported_embedding.md b/docs/my-website/docs/embedding/supported_embedding.md index 009e5337b..3c13c5732 100644 --- a/docs/my-website/docs/embedding/supported_embedding.md +++ b/docs/my-website/docs/embedding/supported_embedding.md @@ -182,6 +182,17 @@ response = embedding( input=["good morning from litellm"] ) ``` +### Usage - Custom API Base +```python +from litellm import embedding +import os +os.environ['HUGGINGFACE_API_KEY'] = "" +response = embedding( + model='huggingface/microsoft/codebert-base', + input=["good morning from litellm"], + api_base = "https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud" +) +``` | Model Name | Function Call | Required OS Variables | |-----------------------|--------------------------------------------------------------|-------------------------------------------------| From 90c13d39acd8fd980892c0a271de3b204ae26335 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 4 Dec 2023 12:01:22 -0800 Subject: [PATCH 027/125] Updated config.yml --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 13fe3d973..8f2a89846 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -35,6 +35,7 @@ jobs: pip install numpydoc pip install traceloop-sdk==0.0.69 pip install openai + pip install prisma - save_cache: paths: - ./venv From 728b879c33f85a6844c4c07ed4d6ccea5b2bb2fc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 12:38:15 -0800 Subject: [PATCH 028/125] fix(utils.py): fix azure streaming bug --- litellm/tests/test_config.yaml | 17 ++++++++++++ litellm/tests/test_proxy_server_cost.py | 27 +++++++++++++++++++ litellm/tests/test_stream_chunk_builder.py | 24 ++++++++++++++++- litellm/utils.py | 30 +++++++++++++++++----- 4 files changed, 91 insertions(+), 7 deletions(-) create mode 100644 litellm/tests/test_proxy_server_cost.py diff --git a/litellm/tests/test_config.yaml b/litellm/tests/test_config.yaml index 0e678d2d3..a38dc7615 100644 --- a/litellm/tests/test_config.yaml +++ b/litellm/tests/test_config.yaml @@ -1,3 +1,20 @@ +model_list: + - model_name: "azure-model" + litellm_params: + model: "azure/gpt-35-turbo" + api_key: "os.environ/AZURE_EUROPE_API_KEY" + api_base: "https://my-endpoint-europe-berri-992.openai.azure.com/" + - model_name: "azure-model" + litellm_params: + model: "azure/gpt-35-turbo" + api_key: "os.environ/AZURE_CANADA_API_KEY" + api_base: "https://my-endpoint-canada-berri992.openai.azure.com" + - model_name: "azure-model" + litellm_params: + model: "azure/gpt-turbo" + api_key: "os.environ/AZURE-FRANCE-API-KEY" + api_base: "https://openai-france-1234.openai.azure.com" + litellm_settings: drop_params: True set_verbose: True diff --git a/litellm/tests/test_proxy_server_cost.py b/litellm/tests/test_proxy_server_cost.py new file mode 100644 index 000000000..7688e5899 --- /dev/null +++ b/litellm/tests/test_proxy_server_cost.py @@ -0,0 +1,27 @@ +# #### What this tests #### +# # This tests the cost tracking function works with consecutive calls (~10 consecutive calls) + +# import sys, os +# import traceback 
+# import pytest +# sys.path.insert( +# 0, os.path.abspath("../..") +# ) # Adds the parent directory to the system path +# import litellm + +# async def test_proxy_cost_tracking(): +# """ +# Get expected cost. +# Create new key. +# Run 10 parallel calls. +# Check cost for key at the end. +# assert it's = expected cost. +# """ +# model = "gpt-3.5-turbo" +# messages = [{"role": "user", "content": "Hey, how's it going?"}] +# number_of_calls = 10 +# expected_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls +# async def litellm_acompletion(): + + + diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py index 807e74cfb..23f67a2e8 100644 --- a/litellm/tests/test_stream_chunk_builder.py +++ b/litellm/tests/test_stream_chunk_builder.py @@ -110,4 +110,26 @@ def test_stream_chunk_builder_litellm_tool_call(): except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") -test_stream_chunk_builder_litellm_tool_call() +# test_stream_chunk_builder_litellm_tool_call() + +def test_stream_chunk_builder_litellm_tool_call_regular_message(): + try: + messages = [{"role": "user", "content": "Hey, how's it going?"}] + litellm.set_verbose = False + response = litellm.completion( + model="azure/gpt-4-nov-release", + messages=messages, + tools=tools_schema, + stream=True, + api_key="os.environ/AZURE_FRANCE_API_KEY", + api_base="https://openai-france-1234.openai.azure.com", + complete_response = True + ) + + print(f"complete response: {response}") + print(f"complete response usage: {response.usage}") + assert response.system_fingerprint is not None + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + +test_stream_chunk_builder_litellm_tool_call_regular_message() diff --git a/litellm/utils.py b/litellm/utils.py index dce259237..3756337b6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5266,14 +5266,32 @@ class CustomStreamWrapper: print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}") print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}") if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string - hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason) + hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason) # filter out bos/eos tokens from openai-compatible hf endpoints print_verbose(f"hold - {hold}, model_response_str - {model_response_str}") if hold is False: - completion_obj["content"] = model_response_str - if self.sent_first_chunk == False: - completion_obj["role"] = "assistant" - self.sent_first_chunk = True - model_response.choices[0].delta = Delta(**completion_obj) + ## check if openai/azure chunk + original_chunk = response_obj.get("original_chunk", None) + if original_chunk: + model_response.id = original_chunk.id + if len(original_chunk.choices) > 0: + try: + delta = dict(original_chunk.choices[0].delta) + model_response.choices[0].delta = Delta(**delta) + except Exception as e: + model_response.choices[0].delta = Delta() + else: + return + model_response.system_fingerprint = original_chunk.system_fingerprint + if self.sent_first_chunk == False: + model_response.choices[0].delta["role"] = "assistant" + self.sent_first_chunk = True + else: + ## else + completion_obj["content"] = model_response_str + if 
self.sent_first_chunk == False: + completion_obj["role"] = "assistant" + self.sent_first_chunk = True + model_response.choices[0].delta = Delta(**completion_obj) # LOGGING threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() print_verbose(f"model_response: {model_response}") From f20bdc9b79a07b0a505b0ff08ca46c0c449e6b9a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 12:45:15 -0800 Subject: [PATCH 029/125] test: fix linting errors --- litellm/proxy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7acdd9b4b..1ea7f47a0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -18,7 +18,7 @@ class PrismaClient: finally: os.chdir(original_dir) # Now you can import the Prisma Client - from prisma import Client + from prisma import Client # type: ignore self.db = Client() #Client to connect to Prisma db def hash_token(self, token: str): From de4a7b719d507259e6cd99d07d2c6a83215b1bd9 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 13:16:19 -0800 Subject: [PATCH 030/125] (test) proxy: reading config.yaml --- .../proxy/example_config_yaml/simple_config.yaml | 4 ++++ litellm/tests/test_proxy_server.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 litellm/proxy/example_config_yaml/simple_config.yaml diff --git a/litellm/proxy/example_config_yaml/simple_config.yaml b/litellm/proxy/example_config_yaml/simple_config.yaml new file mode 100644 index 000000000..14b39a125 --- /dev/null +++ b/litellm/proxy/example_config_yaml/simple_config.yaml @@ -0,0 +1,4 @@ +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo \ No newline at end of file diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 805e2a860..858277dee 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -164,4 +164,17 @@ def test_chat_completion_optional_params(): pytest.fail("LiteLLM Proxy test failed. Exception", e) # Run the test -test_chat_completion_optional_params() +# test_chat_completion_optional_params() + +# Test Reading config.yaml file +from litellm.proxy.proxy_server import load_router_config + +def test_load_router_config(): + try: + print("testing reading config") + result = load_router_config(router=None, config_file_path="../proxy/example_config_yaml/simple_config.yaml") + print(result) + assert len(result[1]) == 1 + except Exception as e: + pytest.fail("Proxy: Got exception reading config", e) +# test_load_router_config() From a99f471d29db0ef788ea8952fd746b82f661a6a0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Dec 2023 13:20:23 -0800 Subject: [PATCH 031/125] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7fad2d9da..cc5c1a3a5 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

Call all LLM APIs using the OpenAI format [Bedrock, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.]

-OpenAI-Compatible Server
+OpenAI Proxy Server
PyPI Version From 50284771b7cbbef0197ee1ac19f14ccae29ba40c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 13:24:35 -0800 Subject: [PATCH 032/125] (test) test_reading proxy --- litellm/proxy/example_config_yaml/azure_config.yaml | 10 ++++------ litellm/tests/test_proxy_server.py | 10 +++++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/example_config_yaml/azure_config.yaml b/litellm/proxy/example_config_yaml/azure_config.yaml index 14e4a786f..fd5865cd7 100644 --- a/litellm/proxy/example_config_yaml/azure_config.yaml +++ b/litellm/proxy/example_config_yaml/azure_config.yaml @@ -4,14 +4,12 @@ model_list: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - azure_ad_token: eyJ0eXAiOiJ + api_key: os.environ/AZURE_API_KEY + tpm: 20_000 - model_name: gpt-4-team2 litellm_params: model: azure/gpt-4 - api_key: sk-123 + api_key: os.environ/AZURE_API_KEY api_base: https://openai-gpt-4-test-v-2.openai.azure.com/ - - model_name: gpt-4-team3 - litellm_params: - model: azure/gpt-4 - api_key: sk-123 + tpm: 100_000 diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 858277dee..189cf6083 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -172,9 +172,17 @@ from litellm.proxy.proxy_server import load_router_config def test_load_router_config(): try: print("testing reading config") + # this is a basic config.yaml with only a model result = load_router_config(router=None, config_file_path="../proxy/example_config_yaml/simple_config.yaml") print(result) assert len(result[1]) == 1 + + # this is a load balancing config yaml + result = load_router_config(router=None, config_file_path="../proxy/example_config_yaml/azure_config.yaml") + print(result) + assert len(result[1]) == 2 + + except Exception as e: pytest.fail("Proxy: Got exception reading config", e) -# test_load_router_config() +test_load_router_config() From 07a20356516cf2eedddb2a4c22147fa9b32f6b93 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 13:26:55 -0800 Subject: [PATCH 033/125] (chore) rm old config examples --- .../example_config_yaml/azure_config_with_tpm.yaml | 13 ------------- .../proxy/example_config_yaml/hosted_litellm.yaml | 8 -------- litellm/proxy/example_config_yaml/queue.yaml | 11 ----------- 3 files changed, 32 deletions(-) delete mode 100644 litellm/proxy/example_config_yaml/azure_config_with_tpm.yaml delete mode 100644 litellm/proxy/example_config_yaml/hosted_litellm.yaml delete mode 100644 litellm/proxy/example_config_yaml/queue.yaml diff --git a/litellm/proxy/example_config_yaml/azure_config_with_tpm.yaml b/litellm/proxy/example_config_yaml/azure_config_with_tpm.yaml deleted file mode 100644 index c1e3e0506..000000000 --- a/litellm/proxy/example_config_yaml/azure_config_with_tpm.yaml +++ /dev/null @@ -1,13 +0,0 @@ -model_list: - - model_name: gpt-3.5-turbo - litellm_params: - model: azure/gpt-35-1 - api_base: https://my-endpoint-canada-berri992.openai.azure.com/ - api_key: 73g - tpm: 80_000 - - model_name: gpt-3.5-turbo - litellm_params: - model: azure/gpt-35-2 - api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ - api_key: 9kj - tpm: 80_000 \ No newline at end of file diff --git a/litellm/proxy/example_config_yaml/hosted_litellm.yaml b/litellm/proxy/example_config_yaml/hosted_litellm.yaml deleted file mode 100644 index 1ede81830..000000000 --- a/litellm/proxy/example_config_yaml/hosted_litellm.yaml +++ /dev/null @@ 
-1,8 +0,0 @@ - -litellm_settings: - set_verbose: True - -general_settings: - master_key: sk-hosted-litellm - use_queue: True - database_url: " # [OPTIONAL] use for token-based auth to proxy diff --git a/litellm/proxy/example_config_yaml/queue.yaml b/litellm/proxy/example_config_yaml/queue.yaml deleted file mode 100644 index 897586669..000000000 --- a/litellm/proxy/example_config_yaml/queue.yaml +++ /dev/null @@ -1,11 +0,0 @@ -model_list: - - model_name: gpt-3.5-turbo - litellm_params: - model: gpt-3.5-turbo - api_key: - - model_name: gpt-3.5-turbo - litellm_params: - model: azure/chatgpt-v-2 # actual model name - api_key: - api_version: 2023-07-01-preview - api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ \ No newline at end of file From 32ecc1a6775afc5973a46a053cd415845e0be6fc Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 13:43:03 -0800 Subject: [PATCH 034/125] (feat) replicate/deployments: add POST Req view --- litellm/llms/replicate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py index d639a8d1e..0f8e23e2a 100644 --- a/litellm/llms/replicate.py +++ b/litellm/llms/replicate.py @@ -100,7 +100,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p logging_obj.pre_call( input=input_data["prompt"], api_key="", - additional_args={"complete_input_dict": initial_prediction_data, "headers": headers}, + additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url}, ) response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers) From 533b93f7144cc968a66487e31fd5d5f1e228ae93 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 14:49:59 -0800 Subject: [PATCH 035/125] (test) proxy: reading configs --- .../example_config_yaml/aliases_config.yaml | 30 +++++++++++++++++++ .../example_config_yaml/azure_config.yaml | 15 ++++++++++ .../example_config_yaml/langfuse_config.yaml | 7 +++++ .../example_config_yaml/load_balancer.yaml | 28 +++++++++++++++++ .../opentelemetry_config.yaml | 7 +++++ .../example_config_yaml/simple_config.yaml | 4 +++ litellm/tests/test_proxy_server.py | 10 +++++-- 7 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 litellm/tests/example_config_yaml/aliases_config.yaml create mode 100644 litellm/tests/example_config_yaml/azure_config.yaml create mode 100644 litellm/tests/example_config_yaml/langfuse_config.yaml create mode 100644 litellm/tests/example_config_yaml/load_balancer.yaml create mode 100644 litellm/tests/example_config_yaml/opentelemetry_config.yaml create mode 100644 litellm/tests/example_config_yaml/simple_config.yaml diff --git a/litellm/tests/example_config_yaml/aliases_config.yaml b/litellm/tests/example_config_yaml/aliases_config.yaml new file mode 100644 index 000000000..266f6cf22 --- /dev/null +++ b/litellm/tests/example_config_yaml/aliases_config.yaml @@ -0,0 +1,30 @@ +model_list: + - model_name: text-davinci-003 + litellm_params: + model: ollama/zephyr + - model_name: gpt-4 + litellm_params: + model: ollama/llama2 + - model_name: gpt-3.5-turbo + litellm_params: + model: ollama/llama2 + temperature: 0.1 + max_tokens: 20 + + +# request to gpt-4, response from ollama/llama2 +# curl --location 'http://0.0.0.0:8000/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --data ' { +# "model": "gpt-4", +# "messages": [ +# { +# "role": "user", +# "content": "what llm are you" +# } +# ], +# } +# ' +# + +# 
{"id":"chatcmpl-27c85cf0-ab09-4bcf-8cb1-0ee950520743","choices":[{"finish_reason":"stop","index":0,"message":{"content":" Hello! I'm just an AI, I don't have personal experiences or emotions like humans do. However, I can help you with any questions or tasks you may have! Is there something specific you'd like to know or discuss?","role":"assistant","_logprobs":null}}],"created":1700094955.373751,"model":"ollama/llama2","object":"chat.completion","system_fingerprint":null,"usage":{"prompt_tokens":12,"completion_tokens":47,"total_tokens":59},"_response_ms":8028.017999999999}% \ No newline at end of file diff --git a/litellm/tests/example_config_yaml/azure_config.yaml b/litellm/tests/example_config_yaml/azure_config.yaml new file mode 100644 index 000000000..fd5865cd7 --- /dev/null +++ b/litellm/tests/example_config_yaml/azure_config.yaml @@ -0,0 +1,15 @@ +model_list: + - model_name: gpt-4-team1 + litellm_params: + model: azure/chatgpt-v-2 + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_version: "2023-05-15" + api_key: os.environ/AZURE_API_KEY + tpm: 20_000 + - model_name: gpt-4-team2 + litellm_params: + model: azure/gpt-4 + api_key: os.environ/AZURE_API_KEY + api_base: https://openai-gpt-4-test-v-2.openai.azure.com/ + tpm: 100_000 + diff --git a/litellm/tests/example_config_yaml/langfuse_config.yaml b/litellm/tests/example_config_yaml/langfuse_config.yaml new file mode 100644 index 000000000..c2a77b5ad --- /dev/null +++ b/litellm/tests/example_config_yaml/langfuse_config.yaml @@ -0,0 +1,7 @@ +model_list: + - model_name: gpt-3.5-turbo + +litellm_settings: + drop_params: True + success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration + diff --git a/litellm/tests/example_config_yaml/load_balancer.yaml b/litellm/tests/example_config_yaml/load_balancer.yaml new file mode 100644 index 000000000..502b90ff9 --- /dev/null +++ b/litellm/tests/example_config_yaml/load_balancer.yaml @@ -0,0 +1,28 @@ +litellm_settings: + drop_params: True + +# Model-specific settings +model_list: # use the same model_name for using the litellm router. 
LiteLLM will use the router between gpt-3.5-turbo + - model_name: gpt-3.5-turbo # litellm will + litellm_params: + model: gpt-3.5-turbo + api_key: sk-uj6F + tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm + rpm: 3 # [OPTIONAL] REPLACE with your openai rpm + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo + api_key: sk-Imn + tpm: 20000 # [OPTIONAL] REPLACE with your openai tpm + rpm: 3 # [OPTIONAL] REPLACE with your openai rpm + - model_name: gpt-3.5-turbo + litellm_params: + model: openrouter/gpt-3.5-turbo + - model_name: mistral-7b-instruct + litellm_params: + model: mistralai/mistral-7b-instruct + +environment_variables: + REDIS_HOST: localhost + REDIS_PASSWORD: + REDIS_PORT: \ No newline at end of file diff --git a/litellm/tests/example_config_yaml/opentelemetry_config.yaml b/litellm/tests/example_config_yaml/opentelemetry_config.yaml new file mode 100644 index 000000000..92d3454d7 --- /dev/null +++ b/litellm/tests/example_config_yaml/opentelemetry_config.yaml @@ -0,0 +1,7 @@ +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo + +general_settings: + otel: True # OpenTelemetry Logger this logs OTEL data to your collector diff --git a/litellm/tests/example_config_yaml/simple_config.yaml b/litellm/tests/example_config_yaml/simple_config.yaml new file mode 100644 index 000000000..14b39a125 --- /dev/null +++ b/litellm/tests/example_config_yaml/simple_config.yaml @@ -0,0 +1,4 @@ +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: gpt-3.5-turbo \ No newline at end of file diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 189cf6083..4a10ecbf8 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -173,16 +173,20 @@ def test_load_router_config(): try: print("testing reading config") # this is a basic config.yaml with only a model - result = load_router_config(router=None, config_file_path="../proxy/example_config_yaml/simple_config.yaml") + result = load_router_config(router=None, config_file_path="example_config_yaml/simple_config.yaml") print(result) assert len(result[1]) == 1 # this is a load balancing config yaml - result = load_router_config(router=None, config_file_path="../proxy/example_config_yaml/azure_config.yaml") + result = load_router_config(router=None, config_file_path="example_config_yaml/azure_config.yaml") print(result) assert len(result[1]) == 2 + # config with general settings - custom callbacks + result = load_router_config(router=None, config_file_path="example_config_yaml/azure_config.yaml") + print(result) + assert len(result[1]) == 2 except Exception as e: pytest.fail("Proxy: Got exception reading config", e) -test_load_router_config() +# test_load_router_config() From 3a4e512a7571d423791335964e4a7f5198a4443c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:06:52 -0800 Subject: [PATCH 036/125] (fix) palm: streaming --- litellm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/utils.py b/litellm/utils.py index 3756337b6..9505addf5 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5232,6 +5232,7 @@ class CustomStreamWrapper: time.sleep(0.05) elif self.custom_llm_provider == "palm": # fake streaming + response_obj = {} if len(self.completion_stream)==0: if self.sent_last_chunk: raise StopIteration From 41365b6e475e3298c5ac70148431aaea960023c1 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:07:07 -0800 Subject: [PATCH 037/125] (test) palm/stream --- 
litellm/tests/test_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 54bc53326..484586325 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -505,7 +505,7 @@ def test_completion_openai_with_optional_params(): except Exception as e: pytest.fail(f"Error occurred: {e}") -test_completion_openai_with_optional_params() +# test_completion_openai_with_optional_params() def test_completion_openai_litellm_key(): try: From 74d520b1b53b1bd29ebf30ceff69873385861287 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:15:39 -0800 Subject: [PATCH 038/125] (docs) sagemaker - clarify max tokens --- docs/my-website/docs/completion/input.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index 30bd20f84..02e3d38d7 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -40,7 +40,7 @@ This list is constantly being updated. |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |VertexAI| ✅ | ✅ | | ✅ | | | | | | | |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | -|Sagemaker| ✅ | ✅ | | ✅ | | | | | | | +|Sagemaker| ✅ | ✅ (only `jumpstart llama2`) | | ✅ | | | | | | | |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | From bc691cbbcda1fafe8fe1b930c7d14e017484bb1e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:19:47 -0800 Subject: [PATCH 039/125] (fix) streaming init response_obj as {} --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 9505addf5..d3a9b8bb0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5125,7 +5125,7 @@ class CustomStreamWrapper: def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) model_response.choices[0].finish_reason = None - response_obj = None + response_obj = {} try: # return this for all models completion_obj = {"content": ""} From b7281825d3c65c220178b6e33a84f09ffb8be59a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:20:06 -0800 Subject: [PATCH 040/125] (test) add streaming sagemaker test --- litellm/tests/test_completion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 484586325..15b547a85 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1053,9 +1053,12 @@ def test_completion_chat_sagemaker(): response = completion( model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f", messages=messages, + stream=True, ) # Add any assertions here to check the response print(response) + for chunk in response: + print(chunk) except Exception as e: pytest.fail(f"Error occurred: {e}") # test_completion_chat_sagemaker() From 9b3a0c69f53f5a58879ba11de3b7c3a4f072cc13 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:24:46 -0800 Subject: [PATCH 041/125] (fix) config testing --- litellm/tests/test_proxy_server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 4a10ecbf8..a525f01bf 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -173,17 +173,18 @@ def test_load_router_config(): try: print("testing 
reading config") # this is a basic config.yaml with only a model - result = load_router_config(router=None, config_file_path="example_config_yaml/simple_config.yaml") + filepath = os.path.dirname(os.path.abspath(__file__)) + result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml") print(result) assert len(result[1]) == 1 # this is a load balancing config yaml - result = load_router_config(router=None, config_file_path="example_config_yaml/azure_config.yaml") + result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml") print(result) assert len(result[1]) == 2 # config with general settings - custom callbacks - result = load_router_config(router=None, config_file_path="example_config_yaml/azure_config.yaml") + result = load_router_config(router=None, config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml") print(result) assert len(result[1]) == 2 From bbdfd143b899ea961f5cb611f38ad244c7afa743 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 15:47:27 -0800 Subject: [PATCH 042/125] (docs) input --- docs/my-website/docs/completion/input.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index 02e3d38d7..7902275ab 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -6,7 +6,7 @@ import TabItem from '@theme/TabItem'; ## Common Params LiteLLM accepts and translates the [OpenAI Chat Completion params](https://platform.openai.com/docs/api-reference/chat/create) across all providers. -### usage +### Usage ```python import litellm @@ -23,7 +23,7 @@ response = litellm.completion( print(response) ``` -### translated OpenAI params +### Translated OpenAI params This is a list of openai params we translate across providers. This list is constantly being updated. 
From a9905bcd0a482995c6a3c564e36270335e3b4126 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 16:00:04 -0800 Subject: [PATCH 043/125] (test) fix config --- litellm/tests/test_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_config.yaml b/litellm/tests/test_config.yaml index a38dc7615..34b3d928a 100644 --- a/litellm/tests/test_config.yaml +++ b/litellm/tests/test_config.yaml @@ -12,7 +12,7 @@ model_list: - model_name: "azure-model" litellm_params: model: "azure/gpt-turbo" - api_key: "os.environ/AZURE-FRANCE-API-KEY" + api_key: "os.environ/AZURE_FRANCE_API_KEY" api_base: "https://openai-france-1234.openai.azure.com" litellm_settings: From 1247afb7a4219664b999cc820e710c79f0357f6b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 16:09:31 -0800 Subject: [PATCH 044/125] (feat) router: set max_retries + timeout --- litellm/router.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index fad67cf6d..5ce0d409b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -857,6 +857,23 @@ class Router: if api_version and api_version.startswith("os.environ/"): api_version_env_name = api_version.replace("os.environ/", "") api_version = litellm.get_secret(api_version_env_name) + + timeout = litellm_params.get("timeout") + if timeout and timeout.startswith("os.environ/"): + timeout_env_name = api_version.replace("os.environ/", "") + timeout = litellm.get_secret(timeout_env_name) + + stream_timeout = litellm_params.get("stream_timeout") + if stream_timeout and stream_timeout.startswith("os.environ/"): + stream_timeout_env_name = api_version.replace("os.environ/", "") + stream_timeout = litellm.get_secret(stream_timeout_env_name) + + max_retries = litellm_params.get("max_retries") + if max_retries and max_retries.startswith("os.environ/"): + max_retries_env_name = api_version.replace("os.environ/", "") + max_retries = litellm.get_secret(max_retries_env_name) + + self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") if "azure" in model_name: if api_version is None: @@ -869,32 +886,44 @@ class Router: model["async_client"] = openai.AsyncAzureOpenAI( api_key=api_key, base_url=api_base, - api_version=api_version + api_version=api_version, + timeout=timeout, + max_retries=max_retries ) model["client"] = openai.AzureOpenAI( api_key=api_key, base_url=api_base, - api_version=api_version + api_version=api_version, + timeout=timeout, + max_retries=max_retries ) else: model["async_client"] = openai.AsyncAzureOpenAI( api_key=api_key, azure_endpoint=api_base, - api_version=api_version + api_version=api_version, + timeout=timeout, + max_retries=max_retries ) model["client"] = openai.AzureOpenAI( api_key=api_key, azure_endpoint=api_base, - api_version=api_version + api_version=api_version, + timeout=timeout, + max_retries=max_retries ) else: model["async_client"] = openai.AsyncOpenAI( api_key=api_key, base_url=api_base, + timeout=timeout, + max_retries=max_retries ) model["client"] = openai.OpenAI( api_key=api_key, base_url=api_base, + timeout=timeout, + max_retries=max_retries ) ############ End of initializing Clients for OpenAI/Azure ################### model_id = "" From e0ccb281d8941ee0992bc26c4a5ad48616125eaf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 16:36:21 -0800 Subject: [PATCH 045/125] feat(utils.py): add async success callbacks for custom functions --- litellm/__init__.py | 1 + 
litellm/integrations/custom_logger.py | 22 ++++- litellm/proxy/proxy_server.py | 44 +++------ litellm/router.py | 2 + litellm/tests/test_custom_logger.py | 113 +++++---------------- litellm/tests/test_proxy_server_cost.py | 125 ++++++++++++++++++++++-- litellm/tests/test_proxy_server_keys.py | 1 + litellm/utils.py | 62 ++++++++++-- 8 files changed, 232 insertions(+), 138 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index b494268ad..b9cf85a55 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -8,6 +8,7 @@ input_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] callbacks: List[Callable] = [] +_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] set_verbose = False diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index af3ea050f..e502439a9 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -8,7 +8,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv import traceback -class CustomLogger: +class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class # Class variables or attributes def __init__(self): pass @@ -29,7 +29,7 @@ class CustomLogger: pass - #### DEPRECATED #### + #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): try: @@ -63,3 +63,21 @@ class CustomLogger: # traceback.print_exc() print_verbose(f"Custom Logger Error - {traceback.format_exc()}") pass + + async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): + # Method definition + try: + kwargs["log_event_type"] = "post_api_call" + await callback_func( + kwargs, # kwargs to func + response_obj, + start_time, + end_time, + ) + print_verbose( + f"Custom Logger - final response object: {response_obj}" + ) + except: + # traceback.print_exc() + print_verbose(f"Custom Logger Error - {traceback.format_exc()}") + pass diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ae3ec5298..8e9ddc9fa 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -272,10 +272,16 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False) async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)): global master_key, prisma_client, llm_model_list + print(f"master_key - {master_key}; api_key - {api_key}") if master_key is None: - return { - "api_key": None - } + if isinstance(api_key, str): + return { + "api_key": api_key.replace("Bearer ", "") + } + else: + return { + "api_key": api_key + } try: if api_key is None: raise Exception("No api key passed in.") @@ -382,8 +388,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): print("Error when loading keys from Azure Key Vault. 
Ensure you run `pip install azure-identity azure-keyvault-secrets`") def cost_tracking(): - global prisma_client, master_key - if prisma_client is not None and master_key is not None: + global prisma_client + if prisma_client is not None: if isinstance(litellm.success_callback, list): print("setting litellm success callback to track cost") if (track_cost_callback) not in litellm.success_callback: # type: ignore @@ -391,7 +397,7 @@ def cost_tracking(): else: litellm.success_callback = track_cost_callback # type: ignore -def track_cost_callback( +async def track_cost_callback( kwargs, # kwargs to completion completion_response: litellm.ModelResponse, # response from completion start_time = None, @@ -420,31 +426,13 @@ def track_cost_callback( response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text) print("regular response_cost", response_cost) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) + print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}") if user_api_key and prisma_client: - # asyncio.run(update_prisma_database(user_api_key, response_cost)) - # Create new event loop for async function execution in the new thread - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - try: - # Run the async function using the newly created event loop - existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key)) - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost - print(f"new cost: {new_spend}") - # Update the cost column for the given token - new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend})) - print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}") - except Exception as e: - print(f"error in creating async loop - {str(e)}") + await update_prisma_database(token=user_api_key, response_cost=response_cost) except Exception as e: print(f"error in tracking cost callback - {str(e)}") async def update_prisma_database(token, response_cost): - try: print(f"Enters prisma db call, token: {token}") # Fetch the existing cost for the given token @@ -460,8 +448,6 @@ async def update_prisma_database(token, response_cost): print(f"new cost: {new_spend}") # Update the cost column for the given token await prisma_client.update_data(token=token, data={"spend": new_spend}) - print(f"Prisma database updated for token {token}. 
New cost: {new_spend}") - except Exception as e: print(f"Error updating Prisma database: {traceback.format_exc()}") pass @@ -648,7 +634,7 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id} + return {"token": token, "expires": new_verification_token.expires, "user_id": user_id} async def delete_verification_token(tokens: List): global prisma_client diff --git a/litellm/router.py b/litellm/router.py index 5ce0d409b..5bf06760e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -876,6 +876,7 @@ class Router: self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") if "azure" in model_name: + self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}") if api_version is None: api_version = "2023-07-01-preview" if "gateway.ai.cloudflare.com" in api_base: @@ -913,6 +914,7 @@ class Router: max_retries=max_retries ) else: + self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") model["async_client"] = openai.AsyncOpenAI( api_key=api_key, base_url=api_base, diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index 7e134bd26..f88bc6868 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -1,5 +1,5 @@ ### What this tests #### -import sys, os, time +import sys, os, time, inspect, asyncio import pytest sys.path.insert(0, os.path.abspath('../..')) @@ -7,6 +7,7 @@ from litellm import completion, embedding import litellm from litellm.integrations.custom_logger import CustomLogger +async_success = False class MyCustomHandler(CustomLogger): success: bool = False failure: bool = False @@ -28,24 +29,29 @@ class MyCustomHandler(CustomLogger): print(f"On Failure") self.failure = True -# def test_chat_openai(): -# try: -# customHandler = MyCustomHandler() -# litellm.callbacks = [customHandler] -# response = completion(model="gpt-3.5-turbo", -# messages=[{ -# "role": "user", -# "content": "Hi 👋 - i'm openai" -# }], -# stream=True) -# time.sleep(1) -# assert customHandler.success == True -# except Exception as e: -# pytest.fail(f"An error occurred - {str(e)}") -# pass +async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time): + global async_success + print(f"ON ASYNC LOGGING") + async_success = True -# test_chat_openai() +@pytest.mark.asyncio +async def test_chat_openai(): + try: + # litellm.set_verbose = True + litellm.success_callback = [async_test_logging_fn] + response = await litellm.acompletion(model="gpt-3.5-turbo", + messages=[{ + "role": "user", + "content": "Hi 👋 - i'm openai" + }], + stream=True) + async for chunk in response: + continue + assert async_success == True + except Exception as e: + print(e) + pytest.fail(f"An error occurred - {str(e)}") def test_completion_azure_stream_moderation_failure(): try: @@ -71,76 +77,3 @@ def test_completion_azure_stream_moderation_failure(): assert customHandler.failure == True except Exception as e: pytest.fail(f"Error occurred: {e}") - -# test_completion_azure_stream_moderation_failure() - - -# def custom_callback( -# kwargs, -# completion_response, -# start_time, -# end_time, -# ): -# print( -# "in custom callback func" -# ) -# print("kwargs", kwargs) -# print(completion_response) -# print(start_time) -# 
print(end_time) -# if "complete_streaming_response" in kwargs: -# print("\n\n complete response\n\n") -# complete_streaming_response = kwargs["complete_streaming_response"] -# print(kwargs["complete_streaming_response"]) -# usage = complete_streaming_response["usage"] -# print("usage", usage) -# def send_slack_alert( -# kwargs, -# completion_response, -# start_time, -# end_time, -# ): -# print( -# "in custom slack callback func" -# ) -# import requests -# import json - -# # Define the Slack webhook URL -# slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>" - -# # Define the text payload, send data available in litellm custom_callbacks -# text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)} -# """ -# payload = { -# "text": text_payload -# } - -# # Set the headers -# headers = { -# "Content-type": "application/json" -# } - -# # Make the POST request -# response = requests.post(slack_webhook_url, json=payload, headers=headers) - -# # Check the response status -# if response.status_code == 200: -# print("Message sent successfully to Slack!") -# else: -# print(f"Failed to send message to Slack. Status code: {response.status_code}") -# print(response.json()) - -# def get_transformed_inputs( -# kwargs, -# ): -# params_to_model = kwargs["additional_args"]["complete_input_dict"] -# print("params to model", params_to_model) - -# litellm.success_callback = [custom_callback, send_slack_alert] -# litellm.failure_callback = [send_slack_alert] - - -# litellm.set_verbose = False - -# # litellm.input_callback = [get_transformed_inputs] diff --git a/litellm/tests/test_proxy_server_cost.py b/litellm/tests/test_proxy_server_cost.py index 7688e5899..b127e72e3 100644 --- a/litellm/tests/test_proxy_server_cost.py +++ b/litellm/tests/test_proxy_server_cost.py @@ -1,27 +1,138 @@ # #### What this tests #### # # This tests the cost tracking function works with consecutive calls (~10 consecutive calls) -# import sys, os +# import sys, os, asyncio # import traceback # import pytest # sys.path.insert( # 0, os.path.abspath("../..") # ) # Adds the parent directory to the system path +# import dotenv +# dotenv.load_dotenv() # import litellm +# from fastapi.testclient import TestClient +# from fastapi import FastAPI +# from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +# filepath = os.path.dirname(os.path.abspath(__file__)) +# config_fp = f"{filepath}/test_config.yaml" +# save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +# app = FastAPI() +# app.include_router(router) # Include your router in the test app +# @app.on_event("startup") +# async def wrapper_startup_event(): +# await startup_event() -# async def test_proxy_cost_tracking(): +# # Here you create a fixture that will be used by your tests +# # Make sure the fixture returns TestClient(app) +# @pytest.fixture(autouse=True) +# def client(): +# with TestClient(app) as client: +# yield client + +# @pytest.mark.asyncio +# async def test_proxy_cost_tracking(client): # """ -# Get expected cost. +# Get min cost. # Create new key. # Run 10 parallel calls. # Check cost for key at the end. 
-# assert it's = expected cost. +# assert it's > min cost. # """ # model = "gpt-3.5-turbo" # messages = [{"role": "user", "content": "Hey, how's it going?"}] -# number_of_calls = 10 -# expected_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls -# async def litellm_acompletion(): +# number_of_calls = 1 +# min_cost = litellm.completion_cost(model=model, messages=messages) * number_of_calls +# try: +# ### CREATE NEW KEY ### +# test_data = { +# "models": ["azure-model"], +# } +# # Your bearer token +# token = os.getenv("PROXY_MASTER_KEY") +# headers = { +# "Authorization": f"Bearer {token}" +# } +# create_new_key = client.post("/key/generate", json=test_data, headers=headers) +# key = create_new_key.json()["key"] +# print(f"received key: {key}") +# ### MAKE PARALLEL CALLS ### +# async def test_chat_completions(): +# # Your test data +# test_data = { +# "model": "azure-model", +# "messages": messages +# } +# tmp_headers = { +# "Authorization": f"Bearer {key}" +# } +# response = client.post("/v1/chat/completions", json=test_data, headers=tmp_headers) + +# assert response.status_code == 200 +# result = response.json() +# print(f"Received response: {result}") +# tasks = [test_chat_completions() for _ in range(number_of_calls)] +# chat_completions = await asyncio.gather(*tasks) +# ### CHECK SPEND ### +# get_key_spend = client.get(f"/key/info?key={key}", headers=headers) + +# assert get_key_spend.json()["info"]["spend"] > min_cost +# # print(f"chat_completions: {chat_completions}") +# # except Exception as e: +# # pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + +# #### JUST TEST LOCAL PROXY SERVER + +# import requests, os +# from concurrent.futures import ThreadPoolExecutor +# import dotenv +# dotenv.load_dotenv() + +# api_url = "http://0.0.0.0:8000/chat/completions" + +# def make_api_call(api_url): +# # Your test data +# test_data = { +# "model": "azure-model", +# "messages": [ +# { +# "role": "user", +# "content": "hi" +# }, +# ], +# "max_tokens": 10, +# } +# # Your bearer token +# token = os.getenv("PROXY_MASTER_KEY") + +# headers = { +# "Authorization": f"Bearer {token}" +# } +# print("testing proxy server") +# response = requests.post(api_url, json=test_data, headers=headers) +# return response.json() + +# # Number of parallel API calls +# num_parallel_calls = 3 + +# # List to store results +# results = [] + +# # Create a ThreadPoolExecutor +# with ThreadPoolExecutor() as executor: +# # Submit the API calls concurrently +# futures = [executor.submit(make_api_call, api_url) for _ in range(num_parallel_calls)] + +# # Gather the results as they become available +# for future in futures: +# try: +# result = future.result() +# results.append(result) +# except Exception as e: +# print(f"Error: {e}") + +# # Print the results +# for idx, result in enumerate(results, start=1): +# print(f"Result {idx}: {result}") diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index 806b5f43e..a2dd396c0 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -59,6 +59,7 @@ def test_add_new_key(client): print(f"response: {response.text}") assert response.status_code == 200 result = response.json() + assert result["key"].startswith("sk-") print(f"Received response: {result}") except Exception as e: pytest.fail("LiteLLM Proxy test failed. 
Exception", e) diff --git a/litellm/utils.py b/litellm/utils.py index d3a9b8bb0..892fc010c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -741,13 +741,9 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) pass - - def success_handler(self, result=None, start_time=None, end_time=None, **kwargs): - print_verbose( - f"Logging Details LiteLLM-Success Call" - ) - try: + def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None): + try: if start_time is None: start_time = self.start_time if end_time is None: @@ -776,6 +772,18 @@ class Logging: float_diff = float(time_diff) litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff) + return start_time, end_time, result, complete_streaming_response + except: + pass + + def success_handler(self, result=None, start_time=None, end_time=None, **kwargs): + print_verbose( + f"Logging Details LiteLLM-Success Call" + ) + try: + start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + print_verbose(f"success callbacks: {litellm.success_callback}") + for callback in litellm.success_callback: try: if callback == "lite_debugger": @@ -969,6 +977,29 @@ class Logging: ) pass + async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs): + """ + Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + """ + start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + print_verbose(f"success callbacks: {litellm.success_callback}") + + for callback in litellm._async_success_callback: + try: + if callable(callback): # custom logger functions + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback + ) + except: + print_verbose( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" + ) + def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): print_verbose( f"Logging Details LiteLLM-Failure Call" @@ -1185,6 +1216,17 @@ def client(original_function): callback_list=callback_list, function_id=function_id ) + ## ASYNC CALLBACKS + if len(litellm.success_callback) > 0: + removed_async_items = [] + for index, callback in enumerate(litellm.success_callback): + if inspect.iscoroutinefunction(callback): + litellm._async_success_callback.append(callback) + removed_async_items.append(index) + + # Pop the async items from success_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + litellm.success_callback.pop(index) if add_breadcrumb: add_breadcrumb( category="litellm.llm_call", @@ -1373,7 +1415,6 @@ def client(original_function): start_time = datetime.datetime.now() result = None logging_obj = kwargs.get("litellm_logging_obj", None) - # only set litellm_call_id if its not in kwargs if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) @@ -1426,8 +1467,8 @@ def client(original_function): # [OPTIONAL] ADD TO CACHE if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object 
litellm.cache.add_cache(result, *args, **kwargs) - - # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated + # LOG SUCCESS - handle streaming success logging in the _next_ object + asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time)) threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() # RETURN RESULT if isinstance(result, ModelResponse): @@ -1465,7 +1506,6 @@ def client(original_function): logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! raise e - # Use httpx to determine if the original function is a coroutine is_coroutine = inspect.iscoroutinefunction(original_function) # Return the appropriate wrapper based on the original function type @@ -5370,6 +5410,8 @@ class CustomStreamWrapper: processed_chunk = self.chunk_creator(chunk=chunk) if processed_chunk is None: continue + ## LOGGING + asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) return processed_chunk raise StopAsyncIteration else: # temporary patch for non-aiohttp async calls From 7adb8f493d0a9b3af16912075c2e45692e4209ea Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 17:01:46 -0800 Subject: [PATCH 046/125] docs(custom_callback.md): add async callbacks to docs --- .../docs/observability/custom_callback.md | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/my-website/docs/observability/custom_callback.md b/docs/my-website/docs/observability/custom_callback.md index 580bd819a..78b7499a8 100644 --- a/docs/my-website/docs/observability/custom_callback.md +++ b/docs/my-website/docs/observability/custom_callback.md @@ -85,6 +85,43 @@ print(response) ``` +## Async Callback Functions + +LiteLLM currently supports just async success callback functions for async completion/embedding calls. + +```python +import asyncio, litellm + +async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time): + print(f"On Async Success!") + +async def test_chat_openai(): + try: + # litellm.set_verbose = True + litellm.success_callback = [async_test_logging_fn] + response = await litellm.acompletion(model="gpt-3.5-turbo", + messages=[{ + "role": "user", + "content": "Hi 👋 - i'm openai" + }], + stream=True) + async for chunk in response: + continue + except Exception as e: + print(e) + pytest.fail(f"An error occurred - {str(e)}") + +asyncio.run(test_chat_openai()) +``` + +:::info + +We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007) + + + +::: + ## What's in kwargs? Notice we pass in a kwargs argument to custom callback. 
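For example, a success callback can pull a few common fields off `kwargs` before doing its own logging. The snippet below is only a sketch — the exact keys available can vary by call type, so the lookups are guarded with `.get()` and the field names should be treated as illustrative rather than a guaranteed schema:

```python
import litellm

async def log_core_fields(kwargs, completion_obj, start_time, end_time):
    # commonly seen fields; keys may be absent for some call types, hence .get()
    model = kwargs.get("model")
    messages = kwargs.get("messages") or []
    metadata = (kwargs.get("litellm_params") or {}).get("metadata")
    duration = (end_time - start_time).total_seconds()  # start/end are datetimes
    print(f"model={model} n_messages={len(messages)} metadata={metadata} took={duration:.2f}s")

# async functions are registered the same way as sync ones
litellm.success_callback = [log_core_fields]
```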
From cba98cf530df53723873f1841ec7840218fbf302 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 16:30:02 -0800 Subject: [PATCH 047/125] (test) init router with 4 clients --- litellm/tests/test_router_init.py | 51 +++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 litellm/tests/test_router_init.py diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py new file mode 100644 index 000000000..806894081 --- /dev/null +++ b/litellm/tests/test_router_init.py @@ -0,0 +1,51 @@ +# this tests if the router is intiaized correctly +import sys, os, time +import traceback, asyncio +import pytest +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm import Router +from concurrent.futures import ThreadPoolExecutor +from collections import defaultdict +from dotenv import load_dotenv +load_dotenv() + + +# everytime we load the router we should have 4 clients: +# Async +# Sync +# Async + Stream +# Sync + Stream + + +def test_init_clients(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + }, + ] + + + router = Router(model_list=model_list) + print(router.model_list) + for elem in router.model_list: + print(elem) + assert elem["client"] is not None + assert elem["async_client"] is not None + assert elem["stream_client"] is not None + assert elem["stream_async_client"] is not None + + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") +# test_init_clients() From 886b52d4480495514be7aa9054af01b46b6aa7a6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 16:59:48 -0800 Subject: [PATCH 048/125] (test) init router clients --- litellm/tests/test_router_init.py | 97 ++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py index 806894081..542242976 100644 --- a/litellm/tests/test_router_init.py +++ b/litellm/tests/test_router_init.py @@ -1,4 +1,4 @@ -# this tests if the router is intiaized correctly +# this tests if the router is initialized correctly import sys, os, time import traceback, asyncio import pytest @@ -12,40 +12,79 @@ from collections import defaultdict from dotenv import load_dotenv load_dotenv() - -# everytime we load the router we should have 4 clients: +# every time we load the router we should have 4 clients: # Async # Sync # Async + Stream # Sync + Stream - def test_init_clients(): - litellm.set_verbose = True - try: - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") - }, - }, - ] + litellm.set_verbose = True + try: + print("testing init 4 clients with diff timeouts") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + "timeout": 0.01, + "stream_timeout": 0.000_001, + "max_retries": 7 + }, + }, + ] + router = Router(model_list=model_list) + for elem in router.model_list: + assert elem["client"] is 
not None + assert elem["async_client"] is not None + assert elem["stream_client"] is not None + assert elem["stream_async_client"] is not None + + # check if timeout for stream/non stream clients is set correctly + async_client = elem["async_client"] + stream_async_client = elem["stream_async_client"] + + assert async_client.timeout == 0.01 + assert stream_async_client.timeout == 0.000_001 + print("PASSED !") + + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +test_init_clients() - router = Router(model_list=model_list) - print(router.model_list) - for elem in router.model_list: - print(elem) - assert elem["client"] is not None - assert elem["async_client"] is not None - assert elem["stream_client"] is not None - assert elem["stream_async_client"] is not None +def test_init_clients_basic(): + litellm.set_verbose = True + try: + print("Test basic client init") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + }, + ] + router = Router(model_list=model_list) + for elem in router.model_list: + assert elem["client"] is not None + assert elem["async_client"] is not None + assert elem["stream_client"] is not None + assert elem["stream_async_client"] is not None + print("PASSED !") + + # see if we can init clients without timeout or max retries set + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +test_init_clients_basic() - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") -# test_init_clients() From 19646091fdc8c2e315d9e2e2c4e2489eabd521c6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 17:29:07 -0800 Subject: [PATCH 049/125] (feat) router: init stream, async stream, async, clients --- litellm/router.py | 97 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 12 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 5bf06760e..edae794c9 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -188,7 +188,7 @@ class Router: data["model"] = original_model_string[:index_of_model_id] else: data["model"] = original_model_string - model_client = deployment.get("client", None) + model_client = self._get_client(deployment=deployment, kwargs=kwargs) return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) except Exception as e: raise e @@ -234,7 +234,7 @@ class Router: data["model"] = original_model_string[:index_of_model_id] else: data["model"] = original_model_string - model_client = deployment.get("async_client", None) + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") self.total_calls[original_model_string] +=1 response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs}) self.success_calls[original_model_string] +=1 @@ -303,7 +303,7 @@ class Router: data["model"] = original_model_string[:index_of_model_id] else: data["model"] = original_model_string - model_client = deployment.get("client", None) + model_client = self._get_client(deployment=deployment, kwargs=kwargs) # call via litellm.embedding() return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) @@ -328,7 +328,7 @@ class 
Router: data["model"] = original_model_string[:index_of_model_id] else: data["model"] = original_model_string - model_client = deployment.get("async_client", None) + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs}) @@ -857,19 +857,19 @@ class Router: if api_version and api_version.startswith("os.environ/"): api_version_env_name = api_version.replace("os.environ/", "") api_version = litellm.get_secret(api_version_env_name) - - timeout = litellm_params.get("timeout") - if timeout and timeout.startswith("os.environ/"): + + timeout = litellm_params.pop("timeout", None) + if isinstance(timeout, str) and timeout.startswith("os.environ/"): timeout_env_name = api_version.replace("os.environ/", "") timeout = litellm.get_secret(timeout_env_name) - stream_timeout = litellm_params.get("stream_timeout") - if stream_timeout and stream_timeout.startswith("os.environ/"): + stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout + if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): stream_timeout_env_name = api_version.replace("os.environ/", "") stream_timeout = litellm.get_secret(stream_timeout_env_name) - - max_retries = litellm_params.get("max_retries") - if max_retries and max_retries.startswith("os.environ/"): + + max_retries = litellm_params.pop("max_retries", 2) + if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): max_retries_env_name = api_version.replace("os.environ/", "") max_retries = litellm.get_secret(max_retries_env_name) @@ -898,6 +898,22 @@ class Router: timeout=timeout, max_retries=max_retries ) + + # streaming clients can have diff timeouts + model["stream_async_client"] = openai.AsyncAzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries + ) + model["stream_client"] = openai.AzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries + ) else: model["async_client"] = openai.AsyncAzureOpenAI( api_key=api_key, @@ -913,6 +929,23 @@ class Router: timeout=timeout, max_retries=max_retries ) + # streaming clients should have diff timeouts + model["stream_async_client"] = openai.AsyncAzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries + ) + + model["stream_client"] = openai.AzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries + ) + else: self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") model["async_client"] = openai.AsyncOpenAI( @@ -927,6 +960,23 @@ class Router: timeout=timeout, max_retries=max_retries ) + + # streaming clients should have diff timeouts + model["stream_async_client"] = openai.AsyncOpenAI( + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries + ) + + # streaming clients should have diff timeouts + model["stream_client"] = openai.OpenAI( + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries + ) + ############ End of initializing Clients for OpenAI/Azure ################### model_id = "" for key in model["litellm_params"]: @@ -947,6 +997,29 @@ class Router: def get_model_names(self): return 
self.model_names + def _get_client(self, deployment, kwargs, client_type=None): + """ + Returns the appropriate client based on the given deployment, kwargs, and client_type. + + Parameters: + deployment (dict): The deployment dictionary containing the clients. + kwargs (dict): The keyword arguments passed to the function. + client_type (str): The type of client to return. + + Returns: + The appropriate client based on the given client_type and kwargs. + """ + if client_type == "async": + if kwargs.get("stream") == True: + return deployment["stream_async_client"] + else: + return deployment["async_client"] + else: + if kwargs.get("stream") == True: + return deployment["stream_client"] + else: + return deployment["client"] + def print_verbose(self, print_statement): if self.set_verbose or litellm.set_verbose: print(f"LiteLLM.Router: {print_statement}") # noqa From fa5b453d395f99a0bcb0650f280c1e4714373cc4 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 17:31:11 -0800 Subject: [PATCH 050/125] (test) init router --- litellm/tests/test_router_init.py | 104 +++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py index 542242976..4d861365e 100644 --- a/litellm/tests/test_router_init.py +++ b/litellm/tests/test_router_init.py @@ -55,7 +55,7 @@ def test_init_clients(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -test_init_clients() +# test_init_clients() def test_init_clients_basic(): @@ -86,5 +86,105 @@ def test_init_clients_basic(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -test_init_clients_basic() +# test_init_clients_basic() + + +def test_timeouts_router(): + """ + Test the timeouts of the router with multiple clients. This HASas to raise a timeout error + """ + import openai + litellm.set_verbose = True + try: + print("testing init 4 clients with diff timeouts") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + "timeout": 0.000001, + "stream_timeout": 0.000_001, + }, + }, + ] + router = Router(model_list=model_list) + + print("PASSED !") + async def test(): + try: + await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello, write a 20 pg essay" + } + ], + ) + except Exception as e: + raise e + asyncio.run(test()) + except openai.APITimeoutError as e: + print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print(type(e)) + pass + except Exception as e: + pytest.fail(f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}") + +# test_timeouts_router() + + +def test_stream_timeouts_router(): + """ + Test the stream timeouts router. 
See if it selected the correct client with stream timeout + """ + import openai + + litellm.set_verbose = True + try: + print("testing init 4 clients with diff timeouts") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + "timeout": 200, # regular calls will not timeout, stream calls will + "stream_timeout": 0.000_001, + }, + }, + ] + router = Router(model_list=model_list) + + print("PASSED !") + selected_client = router._get_client( + deployment=router.model_list[0], + kwargs={ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "hello, write a 20 pg essay" + } + ], + "stream": True + }, + client_type=None + ) + print("Select client timeout", selected_client.timeout) + assert selected_client.timeout == 0.000_001 + except openai.APITimeoutError as e: + print("Passed: Raised correct exception. Got openai.APITimeoutError\nGood Job", e) + print(type(e)) + pass + except Exception as e: + pytest.fail(f"Did not raise error `openai.APITimeoutError`. Instead raised error type: {type(e)}, Error: {e}") + +test_stream_timeouts_router() + From 3f541fe99991610df5e9ff8c8b88f98df8d17fab Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 17:37:09 -0800 Subject: [PATCH 051/125] (docs) custom timeouts proxy --- docs/my-website/docs/proxy/load_balancing.md | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index d9215244e..0e1f58662 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -111,3 +111,31 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ } ' ``` + +## Custom Timeouts, Stream Timeouts - Per Model +For each model you can set `timeout` & `stream_timeout` under `litellm_params` +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/gpt-turbo-small-eu + api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ + api_key: + timeout: 0.1 # timeout in (seconds) + stream_timeout: 0.01 # timeout stream requests (seconds) + max_retries: 5 + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/gpt-turbo-small-ca + api_base: https://my-endpoint-canada-berri992.openai.azure.com/ + api_key: + timeout: 0.1 # timeout in (seconds) + stream_timeout: 0.01 # timeout stream requests (seconds) + max_retries: 5 + +``` + +#### Start Proxy +```shell +$ litellm --config /path/to/config.yaml +``` From acd1678d14c14c8d21cb75254b5fcda2641341a0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 17:51:31 -0800 Subject: [PATCH 052/125] (docs) router --- docs/my-website/docs/proxy/load_balancing.md | 46 +++++++++++++++----- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 0e1f58662..a73cfe332 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -3,38 +3,39 @@ Load balance multiple instances of the same model The proxy will handle routing requests (using LiteLLM's Router). 
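Under the hood this is the same balancing you get from `litellm.Router` directly. A minimal sketch is below — the deployment names, endpoints, and key values are placeholders, and the `rpm` values mirror the yaml config in Step 1:

```python
import os
from litellm import Router

# two deployments share one model_name; requests to that name are spread across them
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/<your-deployment-eu>",      # placeholder
            "api_base": "<your-azure-endpoint-eu>",     # placeholder
            "api_key": os.environ["AZURE_API_KEY"],
            "rpm": 6,
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/<your-deployment-ca>",      # placeholder
            "api_base": "<your-azure-endpoint-ca>",     # placeholder
            "api_key": os.environ["AZURE_API_KEY"],
            "rpm": 6,
        },
    },
]

router = Router(model_list=model_list)
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```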
**Set `rpm` in the config if you want maximize throughput** +## Quick Start - Load Balancing +### Step 1 - Set deployments on config -#### Example config -requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo` +**Example config below**. Here requests with `model=gpt-3.5-turbo` will be routed across multiple instances of `azure/gpt-3.5-turbo` ```yaml model_list: - model_name: gpt-3.5-turbo litellm_params: - model: azure/gpt-turbo-small-eu - api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ - api_key: + model: azure/ + api_base: + api_key: rpm: 6 # Rate limit for this deployment: in requests per minute (rpm) - model_name: gpt-3.5-turbo litellm_params: model: azure/gpt-turbo-small-ca api_base: https://my-endpoint-canada-berri992.openai.azure.com/ - api_key: + api_key: rpm: 6 - model_name: gpt-3.5-turbo litellm_params: model: azure/gpt-turbo-large api_base: https://openai-france-1234.openai.azure.com/ - api_key: + api_key: rpm: 1440 ``` -#### Step 2: Start Proxy with config +### Step 2: Start Proxy with config ```shell $ litellm --config /path/to/config.yaml ``` -#### Step 3: Use proxy +### Step 3: Use proxy - Call a model group [Load Balancing] Curl Command ```shell curl --location 'http://0.0.0.0:8000/chat/completions' \ @@ -51,7 +52,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ ' ``` -### Fallbacks + Cooldowns + Retries + Timeouts +### Usage - Call a specific model deployment +If you want to call a specific model defined in the `config.yaml`, you can call the `litellm_params: model` + +In this example it will call `azure/gpt-turbo-small-ca`. Defined in the config on Step 1 + +```bash +curl --location 'http://0.0.0.0:8000/chat/completions' \ +--header 'Content-Type: application/json' \ +--data ' { + "model": "azure/gpt-turbo-small-ca", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + } +' +``` + + +## Fallbacks + Cooldowns + Retries + Timeouts If a call fails after num_retries, fall back to another model group. @@ -122,7 +144,7 @@ model_list: api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ api_key: timeout: 0.1 # timeout in (seconds) - stream_timeout: 0.01 # timeout stream requests (seconds) + stream_timeout: 0.01 # timeout for stream requests (seconds) max_retries: 5 - model_name: gpt-3.5-turbo litellm_params: @@ -130,7 +152,7 @@ model_list: api_base: https://my-endpoint-canada-berri992.openai.azure.com/ api_key: timeout: 0.1 # timeout in (seconds) - stream_timeout: 0.01 # timeout stream requests (seconds) + stream_timeout: 0.01 # timeout for stream requests (seconds) max_retries: 5 ``` From 05f585153fdc4f79a483acdb621a25102c0436e0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 18:01:21 -0800 Subject: [PATCH 053/125] (docs) add health check on load balancing --- docs/my-website/docs/proxy/load_balancing.md | 40 +++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index a73cfe332..e2e3a7ee6 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -107,7 +107,7 @@ model_list: litellm_settings: num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta) - request_timeout: 10 # raise Timeout error if call takes longer than 10s + request_timeout: 10 # raise Timeout error if call takes longer than 10s. 
Sets litellm.request_timeout fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error allowed_fails: 3 # cooldown model if it fails > 1 call in a minute. @@ -129,7 +129,7 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}], "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}], "num_retries": 2, - "request_timeout": 10 + "timeout": 10 } ' ``` @@ -161,3 +161,39 @@ model_list: ```shell $ litellm --config /path/to/config.yaml ``` + + + +## Health Check LLMs on Proxy +Use this to health check all LLMs defined in your config.yaml +#### Request +Make a GET Request to `/health` on the proxy +```shell +curl --location 'http://0.0.0.0:8000/health' +``` + +You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you +``` +litellm --health +``` +#### Response +```shell +{ + "healthy_endpoints": [ + { + "model": "azure/gpt-35-turbo", + "api_base": "https://my-endpoint-canada-berri992.openai.azure.com/" + }, + { + "model": "azure/gpt-35-turbo", + "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/" + } + ], + "unhealthy_endpoints": [ + { + "model": "azure/gpt-35-turbo", + "api_base": "https://openai-france-1234.openai.azure.com/" + } + ] +} +``` \ No newline at end of file From ac486a3c4ae5ff61e3b2035c83cb717773e6b6af Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 18:08:57 -0800 Subject: [PATCH 054/125] (docs) add example config.yaml --- litellm/proxy/example_config_yaml/azure_config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/litellm/proxy/example_config_yaml/azure_config.yaml b/litellm/proxy/example_config_yaml/azure_config.yaml index fd5865cd7..bd9ff9ac9 100644 --- a/litellm/proxy/example_config_yaml/azure_config.yaml +++ b/litellm/proxy/example_config_yaml/azure_config.yaml @@ -6,10 +6,16 @@ model_list: api_version: "2023-05-15" api_key: os.environ/AZURE_API_KEY tpm: 20_000 + timeout: 5 # 1 second timeout + stream_timeout: 0.5 # 0.5 second timeout for streaming requests + max_retries: 4 - model_name: gpt-4-team2 litellm_params: model: azure/gpt-4 api_key: os.environ/AZURE_API_KEY api_base: https://openai-gpt-4-test-v-2.openai.azure.com/ tpm: 100_000 + timeout: 5 # 1 second timeout + stream_timeout: 0.5 # 0.5 second timeout for streaming requests + max_retries: 4 From 51cddf1e975eb64ef68c150ff5d456fca79dd9a4 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 18:09:25 -0800 Subject: [PATCH 055/125] =?UTF-8?q?bump:=20version=201.10.3=20=E2=86=92=20?= =?UTF-8?q?1.10.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a9878ced7..8c854d297 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.3" +version = "1.10.4" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.3" +version = "1.10.4" version_files = [ "pyproject.toml:^version" ] From 030bd220785765638b57d929a8e05b4be5303e45 Mon Sep 17 00:00:00 2001 From: Krrish 
Dholakia Date: Mon, 4 Dec 2023 18:32:47 -0800 Subject: [PATCH 056/125] feat(proxy_server.py): allow user to override api key auth --- litellm/proxy/custom_auth.py | 14 ++ litellm/proxy/proxy_server.py | 158 +++++------------- litellm/proxy/types.py | 70 ++++++++ litellm/proxy/utils.py | 60 ++++++- litellm/tests/test_configs/custom_auth.py | 14 ++ .../tests/{ => test_configs}/test_config.yaml | 0 .../test_configs/test_config_custom_auth.yaml | 11 ++ litellm/tests/test_proxy_custom_auth.py | 63 +++++++ litellm/tests/test_proxy_server_keys.py | 2 +- 9 files changed, 274 insertions(+), 118 deletions(-) create mode 100644 litellm/proxy/custom_auth.py create mode 100644 litellm/proxy/types.py create mode 100644 litellm/tests/test_configs/custom_auth.py rename litellm/tests/{ => test_configs}/test_config.yaml (100%) create mode 100644 litellm/tests/test_configs/test_config_custom_auth.yaml create mode 100644 litellm/tests/test_proxy_custom_auth.py diff --git a/litellm/proxy/custom_auth.py b/litellm/proxy/custom_auth.py new file mode 100644 index 000000000..0cce561ca --- /dev/null +++ b/litellm/proxy/custom_auth.py @@ -0,0 +1,14 @@ +from litellm.proxy.types import UserAPIKeyAuth +from fastapi import Request +from dotenv import load_dotenv +import os + +load_dotenv() +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: + modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234" + if api_key == modified_master_key: + return UserAPIKeyAuth(api_key=api_key) + raise Exception + except: + raise Exception \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8e9ddc9fa..6f8e0f6ab 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -92,12 +92,16 @@ def generate_feedback_box(): import litellm from litellm.proxy.utils import ( - PrismaClient + PrismaClient, + get_instance_fn ) +import pydantic +from litellm.proxy.types import * from litellm.caching import DualCache litellm.suppress_debug_info = True from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks from fastapi.routing import APIRouter +from fastapi.security import OAuth2PasswordBearer from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse from fastapi.middleware.cors import CORSMiddleware @@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None): return True from typing import Dict -from pydantic import BaseModel -######### Request Class Definition ###### -class ProxyChatCompletionRequest(BaseModel): - model: str - messages: List[Dict[str, str]] - temperature: Optional[float] = None - top_p: Optional[float] = None - n: Optional[int] = None - stream: Optional[bool] = None - stop: Optional[List[str]] = None - max_tokens: Optional[int] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - logit_bias: Optional[Dict[str, float]] = None - user: Optional[str] = None - response_format: Optional[Dict[str, str]] = None - seed: Optional[int] = None - tools: Optional[List[str]] = None - tool_choice: Optional[str] = None - functions: Optional[List[str]] = None # soon to be deprecated - function_call: Optional[str] = None # soon to be deprecated - - # Optional LiteLLM params - caching: Optional[bool] = None - api_base: Optional[str] = None - api_version: Optional[str] = None - api_key: Optional[str] = None - num_retries: Optional[int] = None - context_window_fallback_dict: 
Optional[Dict[str, str]] = None - fallbacks: Optional[List[str]] = None - metadata: Optional[Dict[str, str]] = {} - deployment_id: Optional[str] = None - request_timeout: Optional[int] = None - - class Config: - extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) - -class ModelParams(BaseModel): - model_name: str - litellm_params: dict - model_info: Optional[dict] - class Config: - protected_namespaces = () - -class GenerateKeyRequest(BaseModel): - duration: str = "1h" - models: list = [] - aliases: dict = {} - config: dict = {} - spend: int = 0 - user_id: Optional[str] = None - -class GenerateKeyResponse(BaseModel): - key: str - expires: datetime - user_id: str - -class _DeleteKeyObject(BaseModel): - key: str - -class DeleteKeyRequest(BaseModel): - keys: List[_DeleteKeyObject] - +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") user_api_base = None user_model = None user_debug = False @@ -249,6 +191,7 @@ master_key = None otel_logging = False prisma_client: Optional[PrismaClient] = None user_api_key_cache = DualCache() +user_custom_auth = None ### REDIS QUEUE ### async_result = None celery_app_conn = None @@ -268,21 +211,21 @@ def usage_telemetry( target=litellm.utils.litellm_telemetry, args=(data,), daemon=True ).start() -api_key_header = APIKeyHeader(name="Authorization", auto_error=False) -async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)): - global master_key, prisma_client, llm_model_list - print(f"master_key - {master_key}; api_key - {api_key}") - if master_key is None: - if isinstance(api_key, str): - return { - "api_key": api_key.replace("Bearer ", "") - } - else: - return { - "api_key": api_key - } + +async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth: + global master_key, prisma_client, llm_model_list, user_custom_auth try: + ### USER-DEFINED AUTH FUNCTION ### + if user_custom_auth: + response = await user_custom_auth(request=request, api_key=api_key) + return UserAPIKeyAuth.model_validate(response) + + if master_key is None: + if isinstance(api_key, str): + return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", "")) + else: + return UserAPIKeyAuth() if api_key is None: raise Exception("No api key passed in.") route = request.url.path @@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap # note: never string compare api keys, this is vulenerable to a time attack. 
Use secrets.compare_digest instead is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key) if is_master_key_valid: - return { - "api_key": master_key - } + return UserAPIKeyAuth(api_key=master_key) if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid: raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys") @@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap return_dict = {"api_key": valid_token.token} if valid_token.user_id: return_dict["user_id"] = valid_token.user_id - return return_dict + return UserAPIKeyAuth(**return_dict) else: data = await request.json() model = data.get("model", None) @@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap return_dict = {"api_key": valid_token.token} if valid_token.user_id: return_dict["user_id"] = valid_token.user_id - return return_dict + return UserAPIKeyAuth(**return_dict) else: raise Exception(f"Invalid token") except Exception as e: print(f"An exception occurred - {traceback.format_exc()}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail={"error": "invalid user key"}, + detail="invalid user key", ) def prisma_setup(database_url: Optional[str]): @@ -464,7 +405,7 @@ def run_ollama_serve(): """) def load_router_config(router: Optional[litellm.Router], config_file_path: str): - global master_key, user_config_file_path, otel_logging + global master_key, user_config_file_path, otel_logging, user_custom_auth config = {} try: if os.path.exists(config_file_path): @@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ### LOAD FROM AZURE KEY VAULT ### use_azure_key_vault = general_settings.get("use_azure_key_vault", False) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) - ### CONNECT TO DATABASE ### database_url = general_settings.get("database_url", None) if database_url and database_url.startswith("os.environ/"): @@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): master_key = general_settings.get("master_key", None) if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) - #### OpenTelemetry Logging (OTEL) ######## otel_logging = general_settings.get("otel", False) if otel_logging == True: print("\nOpenTelemetry Logging Activated") - + ### CUSTOM API KEY AUTH ### + custom_auth = general_settings.get("custom_auth", None) + if custom_auth: + user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) 
litellm_settings = config.get('litellm_settings', None) if litellm_settings: @@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): password=cache_password ) elif key == "callbacks": - print(f"{blue_color_code}\nSetting custom callbacks on Proxy") - passed_module, instance_name = value.split(".") - - # Dynamically import the module - module = importlib.import_module(passed_module) - # Get the instance from the module - instance = getattr(module, instance_name) - - methods = [method for method in dir(instance) if callable(getattr(instance, method))] - # Print the methods - print("Methods in the custom callbacks instance:") - for method in methods: - print(method) - - litellm.callbacks = [instance] - print() - + litellm.callbacks = [get_instance_fn(value=value)] else: setattr(litellm, key, value) @@ -844,7 +770,7 @@ def model_list(): @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)]) -async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)): +async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): try: body = await request.body() body_str = body.decode() @@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key except: data = json.loads(body_str) - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key data["model"] = user_model data["call_type"] = "text_completion" if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} return litellm_completion( **data @@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint -async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): +async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): global general_settings, user_debug try: data = {} @@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap # users can pass in 'user' param to /chat/completions. 
Don't override it if data.get("user", None) is None: # if users are using user_api_key_auth, set `user` in `data` - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = request.headers else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = request.headers global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli @@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) -async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): +async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() data = orjson.loads(body) - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id data["model"] = ( general_settings.get("embedding_model", None) # server default or user_model # model name passed via cli args @@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap if user_model: data["model"] = user_model if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} ## ROUTE TO CORRECT ENDPOINT ## router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] diff --git a/litellm/proxy/types.py b/litellm/proxy/types.py new file mode 100644 index 000000000..fbee0732b --- /dev/null +++ b/litellm/proxy/types.py @@ -0,0 +1,70 @@ +from pydantic import BaseModel +from typing import Optional, List, Union, Dict +from datetime import datetime + +######### Request Class Definition ###### +class ProxyChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = None + top_p: Optional[float] = None + n: Optional[int] = None + stream: Optional[bool] = None + stop: Optional[List[str]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + response_format: Optional[Dict[str, str]] = None + seed: Optional[int] = None + tools: Optional[List[str]] = None + tool_choice: Optional[str] = None + functions: Optional[List[str]] = None # soon to be deprecated + function_call: Optional[str] = None # soon to be deprecated + + # Optional LiteLLM params + caching: Optional[bool] = None + api_base: Optional[str] = None + api_version: Optional[str] = None + api_key: Optional[str] = None + num_retries: Optional[int] = None + 
context_window_fallback_dict: Optional[Dict[str, str]] = None + fallbacks: Optional[List[str]] = None + metadata: Optional[Dict[str, str]] = {} + deployment_id: Optional[str] = None + request_timeout: Optional[int] = None + + class Config: + extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) + +class ModelParams(BaseModel): + model_name: str + litellm_params: dict + model_info: Optional[dict] + class Config: + protected_namespaces = () + +class GenerateKeyRequest(BaseModel): + duration: str = "1h" + models: list = [] + aliases: dict = {} + config: dict = {} + spend: int = 0 + user_id: Optional[str] = None + +class GenerateKeyResponse(BaseModel): + key: str + expires: datetime + user_id: str + +class _DeleteKeyObject(BaseModel): + key: str + +class DeleteKeyRequest(BaseModel): + keys: List[_DeleteKeyObject] + + +class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth + api_key: Optional[str] = None + user_id: Optional[str] = None \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 1ea7f47a0..5b2039543 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,6 +1,7 @@ from typing import Optional, List, Any -import os, subprocess, hashlib +import os, subprocess, hashlib, importlib +### DB CONNECTOR ### class PrismaClient: def __init__(self, database_url: str): print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") @@ -95,3 +96,60 @@ class PrismaClient: async def disconnect(self): await self.db.disconnect() +# ### CUSTOM FILE ### +# def get_instance_fn(value: str, config_file_path: Optional[str]=None): +# try: +# # Split the path by dots to separate module from instance +# parts = value.split(".") +# # The module path is all but the last part, and the instance is the last part +# module_path = ".".join(parts[:-1]) +# instance_name = parts[-1] + +# if config_file_path is not None: +# directory = os.path.dirname(config_file_path) +# module_path = os.path.join(directory, module_path) +# # Dynamically import the module +# module = importlib.import_module(module_path) + +# # Get the instance from the module +# instance = getattr(module, instance_name) + +# return instance +# except ImportError as e: +# print(e) +# raise ImportError(f"Could not import file at {value}") + +def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: + try: + print(f"value: {value}") + # Split the path by dots to separate module from instance + parts = value.split(".") + + # The module path is all but the last part, and the instance_name is the last part + module_name = ".".join(parts[:-1]) + instance_name = parts[-1] + + # If config_file_path is provided, use it to determine the module spec and load the module + if config_file_path is not None: + directory = os.path.dirname(config_file_path) + module_file_path = os.path.join(directory, *module_name.split('.')) + module_file_path += '.py' + + spec = importlib.util.spec_from_file_location(module_name, module_file_path) + if spec is None: + raise ImportError(f"Could not find a module specification for {module_file_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + else: + # Dynamically import the module + module = importlib.import_module(module_name) + + # Get the instance from the module + instance = getattr(module, instance_name) + + return instance + except ImportError as e: + # Print the error message for easier debugging + print(e) + # Re-raise the exception 
with a user-friendly message + raise ImportError(f"Could not import {instance_name} from {module_name}") from e \ No newline at end of file diff --git a/litellm/tests/test_configs/custom_auth.py b/litellm/tests/test_configs/custom_auth.py new file mode 100644 index 000000000..f9de3a97a --- /dev/null +++ b/litellm/tests/test_configs/custom_auth.py @@ -0,0 +1,14 @@ +from litellm.proxy.types import UserAPIKeyAuth +from fastapi import Request +from dotenv import load_dotenv +import os + +load_dotenv() +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: + print(f"api_key: {api_key}") + if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": + return UserAPIKeyAuth(api_key=api_key) + raise Exception + except: + raise Exception \ No newline at end of file diff --git a/litellm/tests/test_config.yaml b/litellm/tests/test_configs/test_config.yaml similarity index 100% rename from litellm/tests/test_config.yaml rename to litellm/tests/test_configs/test_config.yaml diff --git a/litellm/tests/test_configs/test_config_custom_auth.yaml b/litellm/tests/test_configs/test_config_custom_auth.yaml new file mode 100644 index 000000000..33088bd1c --- /dev/null +++ b/litellm/tests/test_configs/test_config_custom_auth.yaml @@ -0,0 +1,11 @@ +model_list: + - model_name: "openai-model" + litellm_params: + model: "gpt-3.5-turbo" + +litellm_settings: + drop_params: True + set_verbose: True + +general_settings: + custom_auth: custom_auth.user_api_key_auth \ No newline at end of file diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py new file mode 100644 index 000000000..fa1b5f6dd --- /dev/null +++ b/litellm/tests/test_proxy_custom_auth.py @@ -0,0 +1,63 @@ +import sys, os +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os, io + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import embedding, completion, completion_cost, Timeout +from litellm import RateLimitError + +# test /chat/completion request to the proxy +from fastapi.testclient import TestClient +from fastapi import FastAPI +from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +filepath = os.path.dirname(os.path.abspath(__file__)) +config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" +save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +app = FastAPI() +app.include_router(router) # Include your router in the test app +@app.on_event("startup") +async def wrapper_startup_event(): + await startup_event() + +# Here you create a fixture that will be used by your tests +# Make sure the fixture returns TestClient(app) +@pytest.fixture(autouse=True) +def client(): + with TestClient(app) as client: + yield client + +def test_custom_auth(client): + try: + # Your test data + test_data = { + "model": "openai-model", + "messages": [ + { + "role": "user", + "content": "hi" + }, + ], + "max_tokens": 10, + } + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/chat/completions", json=test_data, 
headers=headers) + print(f"response: {response.text}") + assert response.status_code == 401 + result = response.json() + print(f"Received response: {result}") + except Exception as e: + pytest.fail("LiteLLM Proxy test failed. Exception", e) \ No newline at end of file diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index a2dd396c0..fb0ec2f3c 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -25,7 +25,7 @@ from fastapi.testclient import TestClient from fastapi import FastAPI from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined filepath = os.path.dirname(os.path.abspath(__file__)) -config_fp = f"{filepath}/test_config.yaml" +config_fp = f"{filepath}/test_configs/test_config.yaml" save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) app = FastAPI() app.include_router(router) # Include your router in the test app From 31f3187670d22a72affd55717e8adb7d9b782980 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 18:43:01 -0800 Subject: [PATCH 057/125] test: fix linting errors --- litellm/proxy/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 5b2039543..9c6a2c17e 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -139,7 +139,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: if spec is None: raise ImportError(f"Could not find a module specification for {module_file_path}") module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + spec.loader.exec_module(module) # type: ignore else: # Dynamically import the module module = importlib.import_module(module_name) @@ -149,7 +149,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: return instance except ImportError as e: - # Print the error message for easier debugging - print(e) # Re-raise the exception with a user-friendly message - raise ImportError(f"Could not import {instance_name} from {module_name}") from e \ No newline at end of file + raise ImportError(f"Could not import {instance_name} from {module_name}") from e + except Exception as e: + raise e \ No newline at end of file From 9ba17657ad664a21b5e91259a152db58540be024 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 20:50:06 -0800 Subject: [PATCH 058/125] (feat) init redis cache with **kwargs --- litellm/caching.py | 10 ++++++---- litellm/tests/test_caching.py | 27 +++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 5e8fcf447..d9b94b958 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -69,10 +69,10 @@ class InMemoryCache(BaseCache): class RedisCache(BaseCache): - def __init__(self, host, port, password): + def __init__(self, host, port, password, **kwargs): import redis # if users don't provider one, use the default litellm cache - self.redis_client = redis.Redis(host=host, port=port, password=password) + self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs) def set_cache(self, key, value, **kwargs): ttl = 
kwargs.get("ttl", None) @@ -168,7 +168,8 @@ class Cache: type="local", host=None, port=None, - password=None + password=None, + **kwargs ): """ Initializes the cache based on the given type. @@ -178,6 +179,7 @@ class Cache: host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". + **kwargs: Additional keyword arguments for redis.Redis() cache Raises: ValueError: If an invalid cache type is provided. @@ -186,7 +188,7 @@ class Cache: None """ if type == "redis": - self.cache = RedisCache(host, port, password) + self.cache = RedisCache(host, port, password, **kwargs) if type == "local": self.cache = InMemoryCache() if "cache" not in litellm.input_callback: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index ab24d3e70..713f97b3e 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -90,7 +90,7 @@ def test_embedding_caching(): print(f"embedding2: {embedding2}") pytest.fail("Error occurred: Embedding caching failed") -test_embedding_caching() +# test_embedding_caching() def test_embedding_caching_azure(): @@ -190,7 +190,7 @@ def test_redis_cache_completion(): print(f"response4: {response4}") pytest.fail(f"Error occurred:") -test_redis_cache_completion() +# test_redis_cache_completion() # redis cache with custom keys def custom_get_cache_key(*args, **kwargs): @@ -231,6 +231,29 @@ def test_custom_redis_cache_with_key(): # test_custom_redis_cache_with_key() + +def test_custom_redis_cache_params(): + # test if we can init redis with **kwargs + try: + litellm.cache = Cache( + type="redis", + host=os.environ['REDIS_HOST'], + port=os.environ['REDIS_PORT'], + password=os.environ['REDIS_PASSWORD'], + db = 0, + ssl=True, + ssl_certfile="./redis_user.crt", + ssl_keyfile="./redis_user_private.key", + ssl_ca_certs="./redis_ca.pem", + ) + + print(litellm.cache.cache.redis_client) + litellm.cache = None + except Exception as e: + pytest.fail(f"Error occurred:", e) + +# test_custom_redis_cache_params() + # def test_redis_cache_with_ttl(): # cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) # sample_model_response_object_str = """{ From 1ff8f757527693f5d002ed84113fd55ddc64ced9 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 4 Dec 2023 21:19:32 -0800 Subject: [PATCH 059/125] Updated config.yml --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 8f2a89846..caf0e8396 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -45,7 +45,7 @@ jobs: command: | cd litellm python -m pip install types-requests types-setuptools types-redis - if ! python -m mypy . --ignore-missing-imports --explicit-package-bases; then + if ! python -m mypy . 
--ignore-missing-imports; then echo "mypy detected errors" exit 1 fi From 71e64c34cb403f1d550145f2cf370f612871ce04 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 07:28:03 -0800 Subject: [PATCH 060/125] fix(huggingface_restapi.py): raise better exceptions for unprocessable hf responses --- litellm/llms/huggingface_restapi.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 25aa1c574..c347910f8 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -170,6 +170,11 @@ class Huggingface(BaseLLM): "content" ] = completion_response["generated_text"] # type: ignore elif task == "text-generation-inference": + if (not isinstance(completion_response, list) + or not isinstance(completion_response[0], dict) + or "generated_text" not in completion_response[0]): + raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}") + if len(completion_response[0]["generated_text"]) > 0: model_response["choices"][0]["message"][ "content" From 943bf53b0bd7c2b5ad8a77788674635e5919d5df Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 07:28:14 -0800 Subject: [PATCH 061/125] =?UTF-8?q?bump:=20version=201.10.4=20=E2=86=92=20?= =?UTF-8?q?1.10.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c854d297..2e4b11e82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.4" +version = "1.10.5" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.4" +version = "1.10.5" version_files = [ "pyproject.toml:^version" ] From 13261287ecbbcf8af1e8956a476bb90f67f2f96a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 09:05:53 -0800 Subject: [PATCH 062/125] (fix) proxy: bug non OpenAI LLMs --- litellm/router.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index edae794c9..75fd5afd9 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1011,14 +1011,14 @@ class Router: """ if client_type == "async": if kwargs.get("stream") == True: - return deployment["stream_async_client"] + return deployment.get("stream_async_client", None) else: - return deployment["async_client"] + return deployment.get("async_client", None) else: if kwargs.get("stream") == True: - return deployment["stream_client"] + return deployment.get("stream_client", None) else: - return deployment["client"] + return deployment.get("client", None) def print_verbose(self, print_statement): if self.set_verbose or litellm.set_verbose: From 3bdf61f02a5232db7dfcdddf9f24b8852ee2cdb8 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 09:35:08 -0800 Subject: [PATCH 063/125] (test) test bedrock on router --- litellm/tests/test_router.py | 37 +++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 8024c9dd2..ae9a06c41 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -329,4 +329,39 @@ def test_azure_aembedding_on_router(): 
except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") -# test_azure_aembedding_on_router() \ No newline at end of file +# test_azure_aembedding_on_router() + + +def test_bedrock_on_router(): + litellm.set_verbose = True + print("\n Testing bedrock on router\n") + try: + model_list = [ + { + "model_name": "claude-v1", + "litellm_params": { + "model": "bedrock/anthropic.claude-instant-v1", + }, + "tpm": 100000, + "rpm": 10000, + }, + ] + + async def test(): + router = Router(model_list=model_list) + response = await router.acompletion( + model="claude-v1", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ] + ) + print(response) + router.reset() + asyncio.run(test()) + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") +# test_bedrock_on_router() \ No newline at end of file From 732a049513af21681e63c757e7c25d4a9293914c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 09:36:38 -0800 Subject: [PATCH 064/125] (fix) patch max_retries for non openai llms --- litellm/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/utils.py b/litellm/utils.py index 892fc010c..ea0b4002f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1985,6 +1985,8 @@ def get_optional_params( # use the openai defaults if k not in supported_params: if k == "n" and n == 1: # langchain sends n=1 as a default value pass + if k == "max_retries": # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries + pass # Always keeps this in elif code blocks else: unsupported_params[k] = non_default_params[k] From 33cf5a3371b393d8453c02ded12dcc128e9aafa3 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 09:36:59 -0800 Subject: [PATCH 065/125] =?UTF-8?q?bump:=20version=201.10.5=20=E2=86=92=20?= =?UTF-8?q?1.10.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2e4b11e82..029e925a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.5" +version = "1.10.6" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.5" +version = "1.10.6" version_files = [ "pyproject.toml:^version" ] From a602d59645f573b9cb6d3467c4da13f9b3451888 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 10:00:54 -0800 Subject: [PATCH 066/125] (fix) bug in completion: _check_valid_arg --- litellm/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index ea0b4002f..84cd20729 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1984,9 +1984,9 @@ def get_optional_params( # use the openai defaults for k in non_default_params.keys(): if k not in supported_params: if k == "n" and n == 1: # langchain sends n=1 as a default value - pass + continue # skip this param if k == "max_retries": # TODO: This is a patch. We support max retries for OpenAI, Azure. 
For non OpenAI LLMs we need to add support for max retries - pass + continue # skip this param # Always keeps this in elif code blocks else: unsupported_params[k] = non_default_params[k] From b46c73a46e3dd76391be21a8a66753cb9352681b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 11:13:09 -0800 Subject: [PATCH 067/125] fix: fix proxy testing --- litellm/proxy/proxy_server.py | 38 +++++++++------- litellm/tests/test_configs/test_config.yaml | 49 +++++++++++---------- litellm/tests/test_proxy_custom_auth.py | 4 +- litellm/tests/test_proxy_server.py | 35 +++++++++------ 4 files changed, 73 insertions(+), 53 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6f8e0f6ab..9d597ac01 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -168,7 +168,7 @@ def log_input_output(request, response, custom_logger=None): from typing import Dict -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) user_api_base = None user_model = None user_debug = False @@ -213,9 +213,13 @@ def usage_telemetry( -async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth: +async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth: global master_key, prisma_client, llm_model_list, user_custom_auth try: + if isinstance(api_key, str): + assert api_key.startswith("Bearer ") # ensure Bearer token passed in + api_key = api_key.replace("Bearer ", "") # extract the token + print(f"api_key: {api_key}; master_key: {master_key}; user_custom_auth: {user_custom_auth}") ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth: response = await user_custom_auth(request=request, api_key=api_key) @@ -223,15 +227,16 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche if master_key is None: if isinstance(api_key, str): - return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", "")) - else: - return UserAPIKeyAuth() - if api_key is None: + return UserAPIKeyAuth(api_key=api_key) + else: + return UserAPIKeyAuth() + + if api_key is None: # only require api key if master key is set raise Exception("No api key passed in.") route = request.url.path # note: never string compare api keys, this is vulenerable to a time attack. 
Use secrets.compare_digest instead - is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key) + is_master_key_valid = secrets.compare_digest(api_key, master_key) if is_master_key_valid: return UserAPIKeyAuth(api_key=master_key) @@ -241,9 +246,9 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche if prisma_client: ## check for cache hit (In-Memory Cache) valid_token = user_api_key_cache.get_cache(key=api_key) - if valid_token is None and "Bearer " in api_key: + if valid_token is None: ## check db - cleaned_api_key = api_key[len("Bearer "):] + cleaned_api_key = api_key valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow()) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) elif valid_token is not None: @@ -275,10 +280,10 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche raise Exception(f"Invalid token") except Exception as e: print(f"An exception occurred - {traceback.format_exc()}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="invalid user key", - ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid user key", + ) def prisma_setup(database_url: Optional[str]): global prisma_client @@ -597,13 +602,17 @@ def initialize( config, use_queue ): - global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings + global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth generate_feedback_box() user_model = model user_debug = debug dynamic_config = {"general": {}, user_model: {}} if config: llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config) + else: + # reset auth if config not passed, needed for consecutive tests on proxy + master_key = None + user_custom_auth = None if headers: # model-specific param user_headers = headers dynamic_config[user_model]["headers"] = headers @@ -810,7 +819,6 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key detail=error_msg ) - @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint diff --git a/litellm/tests/test_configs/test_config.yaml b/litellm/tests/test_configs/test_config.yaml index 34b3d928a..fa2079666 100644 --- a/litellm/tests/test_configs/test_config.yaml +++ b/litellm/tests/test_configs/test_config.yaml @@ -1,24 +1,27 @@ -model_list: - - model_name: "azure-model" - litellm_params: - model: "azure/gpt-35-turbo" - api_key: "os.environ/AZURE_EUROPE_API_KEY" - api_base: "https://my-endpoint-europe-berri-992.openai.azure.com/" - - model_name: "azure-model" - litellm_params: - model: "azure/gpt-35-turbo" - api_key: "os.environ/AZURE_CANADA_API_KEY" - api_base: "https://my-endpoint-canada-berri992.openai.azure.com" - - model_name: "azure-model" - litellm_params: - model: "azure/gpt-turbo" - api_key: 
"os.environ/AZURE_FRANCE_API_KEY" - api_base: "https://openai-france-1234.openai.azure.com" - -litellm_settings: - drop_params: True - set_verbose: True - general_settings: - master_key: "os.environ/PROXY_MASTER_KEY" - database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy + database_url: os.environ/PROXY_DATABASE_URL + master_key: os.environ/PROXY_MASTER_KEY +litellm_settings: + drop_params: true + set_verbose: true +model_list: +- litellm_params: + api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ + api_key: os.environ/AZURE_EUROPE_API_KEY + model: azure/gpt-35-turbo + model_name: azure-model +- litellm_params: + api_base: https://my-endpoint-canada-berri992.openai.azure.com + api_key: os.environ/AZURE_CANADA_API_KEY + model: azure/gpt-35-turbo + model_name: azure-model +- litellm_params: + api_base: https://openai-france-1234.openai.azure.com + api_key: os.environ/AZURE_FRANCE_API_KEY + model: azure/gpt-turbo + model_name: azure-model +- litellm_params: + model: gpt-3.5-turbo + model_info: + description: this is a test openai model + model_name: test_openai_models diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py index fa1b5f6dd..5708b1c41 100644 --- a/litellm/tests/test_proxy_custom_auth.py +++ b/litellm/tests/test_proxy_custom_auth.py @@ -18,7 +18,7 @@ from litellm import RateLimitError # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) @@ -26,7 +26,7 @@ app = FastAPI() app.include_router(router) # Include your router in the test app @app.on_event("startup") async def wrapper_startup_event(): - await startup_event() + initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) # Here you create a fixture that will be used by your tests # Make sure the fixture returns TestClient(app) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index a525f01bf..b15ee8307 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -18,11 +18,22 @@ from litellm import RateLimitError # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI -from litellm.proxy.proxy_server import router # Replace with the actual module where your FastAPI router is defined +from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined +save_worker_config(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, 
temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) app = FastAPI() app.include_router(router) # Include your router in the test app -client = TestClient(app) -def test_chat_completion(): +@app.on_event("startup") +async def wrapper_startup_event(): # required to reset config on app init - b/c pytest collects across multiple files - which sets the fastapi client + WORKER CONFIG to whatever was collected last + initialize(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) + +# Here you create a fixture that will be used by your tests +# Make sure the fixture returns TestClient(app) +@pytest.fixture(autouse=True) +def client(): + with TestClient(app) as client: + yield client + +def test_chat_completion(client): try: # Your test data test_data = { @@ -37,18 +48,16 @@ def test_chat_completion(): } print("testing proxy server") response = client.post("/v1/chat/completions", json=test_data) - + print(f"response - {response.text}") assert response.status_code == 200 result = response.json() print(f"Received response: {result}") except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") # Run the test -# test_chat_completion() - -def test_chat_completion_azure(): +def test_chat_completion_azure(client): try: # Your test data test_data = { @@ -69,13 +78,13 @@ def test_chat_completion_azure(): print(f"Received response: {result}") assert len(result["choices"][0]["message"]["content"]) > 0 except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") # Run the test # test_chat_completion_azure() -def test_embedding(): +def test_embedding(client): try: test_data = { "model": "azure/azure-embedding-model", @@ -89,13 +98,13 @@ def test_embedding(): print(len(result["data"][0]["embedding"])) assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) + pytest.fail(f"LiteLLM Proxy test failed. 
Exception - {str(e)}") # Run the test # test_embedding() -def test_add_new_model(): +def test_add_new_model(client): try: test_data = { "model_name": "test_openai_models", @@ -135,7 +144,7 @@ class MyCustomHandler(CustomLogger): customHandler = MyCustomHandler() -def test_chat_completion_optional_params(): +def test_chat_completion_optional_params(client): # [PROXY: PROD TEST] - DO NOT DELETE # This tests if all the /chat/completion params are passed to litellm From ddea62fdb1538c59d17f9867a27a8470fc86c36d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 11:14:16 -0800 Subject: [PATCH 068/125] refactor(proxy_server.py): clean up print statements in proxy server --- litellm/proxy/proxy_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9d597ac01..ecd4ab8d6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -219,7 +219,6 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap if isinstance(api_key, str): assert api_key.startswith("Bearer ") # ensure Bearer token passed in api_key = api_key.replace("Bearer ", "") # extract the token - print(f"api_key: {api_key}; master_key: {master_key}; user_custom_auth: {user_custom_auth}") ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth: response = await user_custom_auth(request=request, api_key=api_key) From f0c704f3c2fca2ffec6f1d2a8d5eceebc581b470 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 11:45:07 -0800 Subject: [PATCH 069/125] (docs) add example on using proxy with OpenAI --- docs/my-website/docs/proxy/quick_start.md | 66 +++++++++++------------ 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index f1749bc50..6ea6433df 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -43,7 +43,7 @@ litellm --test This will now automatically route any requests for gpt-3.5-turbo to bigcode starcoder, hosted on huggingface inference endpoints. -### Using LiteLLM Proxy - Curl Request, OpenAI Package +### Using LiteLLM Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS @@ -84,7 +84,38 @@ print(response) ``` + +```python +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import HumanMessage, SystemMessage + +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", + model = "gpt-3.5-turbo", + temperature=0.1 +) + +messages = [ + SystemMessage( + content="You are a helpful assistant that im using to make a test request to." + ), + HumanMessage( + content="test from litellm. 
tell me why it's amazing in 1 sentence" + ), +] +response = chat(messages) + +print(response) + +``` + + @@ -474,37 +505,4 @@ curl -X POST \ https://api.openai.com/v1/chat/completions \ -H 'content-type: application/json' -H 'Authorization: Bearer sk-qnWGUIW9****************************************' \ -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "this is a test request, write a short poem"}]}' -``` - -## Health Check LLMs on Proxy -Use this to health check all LLMs defined in your config.yaml -#### Request -```shell -curl --location 'http://0.0.0.0:8000/health' -``` - -You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you -``` -litellm --health -``` -#### Response -```shell -{ - "healthy_endpoints": [ - { - "model": "azure/gpt-35-turbo", - "api_base": "https://my-endpoint-canada-berri992.openai.azure.com/" - }, - { - "model": "azure/gpt-35-turbo", - "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/" - } - ], - "unhealthy_endpoints": [ - { - "model": "azure/gpt-35-turbo", - "api_base": "https://openai-france-1234.openai.azure.com/" - } - ] -} ``` \ No newline at end of file From a0f8bf23abcc61ec7c67f25b313016cfe4ac45e6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 11:45:34 -0800 Subject: [PATCH 070/125] (test) proxy: langchain compatible --- litellm/proxy/tests/test_langchain_request.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 litellm/proxy/tests/test_langchain_request.py diff --git a/litellm/proxy/tests/test_langchain_request.py b/litellm/proxy/tests/test_langchain_request.py new file mode 100644 index 000000000..af6691f3c --- /dev/null +++ b/litellm/proxy/tests/test_langchain_request.py @@ -0,0 +1,28 @@ +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import HumanMessage, SystemMessage + +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", + model = "gpt-3.5-turbo", + temperature=0.1 +) + +messages = [ + SystemMessage( + content="You are a helpful assistant that im using to make a test request to." + ), + HumanMessage( + content="test from litellm. 
tell me why it's amazing in 1 sentence" + ), +] +response = chat(messages) + +print(response) + + + From 41075dc977c1529af88c92f0b1e83452022e406b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 11:46:42 -0800 Subject: [PATCH 071/125] (docs) litellm proxy + langchain --- docs/my-website/docs/proxy/quick_start.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index 6ea6433df..b88e2b610 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -96,7 +96,7 @@ from langchain.prompts.chat import ( from langchain.schema import HumanMessage, SystemMessage chat = ChatOpenAI( - openai_api_base="http://0.0.0.0:8000", + openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy model = "gpt-3.5-turbo", temperature=0.1 ) @@ -112,7 +112,6 @@ messages = [ response = chat(messages) print(response) - ``` From a85a9d7e0004df077f237d0052eae69afad989aa Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 11:55:02 -0800 Subject: [PATCH 072/125] (docs) using proxy with curl, OpenAI, langchain --- docs/my-website/docs/proxy/configs.md | 75 +++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 71ce7de02..ed5424b1a 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -1,3 +1,7 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # Proxy Config.yaml Set model list, `api_base`, `api_key`, `temperature` & proxy server settings (`master-key`) on the config.yaml. 
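The proxy turns this model list into a `litellm.Router` when it starts up, so the same setup can be reproduced in plain Python without the proxy. The sketch below is illustrative only: the deployment name, endpoint, and environment variable are placeholder values borrowed from the example configs, not a real deployment.

```python
# Rough programmatic equivalent of a config.yaml model list (illustrative values).
import os
import litellm

model_list = [
    {
        "model_name": "gpt-3.5-turbo",      # group name that callers reference
        "litellm_params": {                 # params forwarded to litellm.completion()
            "model": "azure/gpt-35-turbo",
            "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",  # placeholder endpoint
            "api_key": os.environ.get("AZURE_CANADA_API_KEY", ""),               # placeholder env var
        },
        "rpm": 6,                           # per-deployment rate limit (requests per minute)
    },
    {
        "model_name": "bedrock-claude-v1",
        "litellm_params": {"model": "bedrock/anthropic.claude-instant-v1"},
    },
]

router = litellm.Router(model_list=model_list)

# Requests for "gpt-3.5-turbo" are routed across all deployments registered under that name.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```
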
@@ -26,6 +30,9 @@ model_list: api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU") rpm: 6 # Rate limit for this deployment: in requests per minute (rpm) + - model_name: bedrock-claude-v1 + litellm_params: + model: bedrock/anthropic.claude-instant-v1 - model_name: gpt-3.5-turbo litellm_params: model: azure/gpt-turbo-small-ca @@ -54,13 +61,18 @@ general_settings: $ litellm --config /path/to/config.yaml ``` -#### Step 3: Use proxy -Curl Command + +### Using Proxy - Curl Request, OpenAI Package, Langchain, Langchain JS +Calling a model group + + + + ```shell curl --location 'http://0.0.0.0:8000/chat/completions' \ --header 'Content-Type: application/json' \ --data ' { - "model": "zephyr-alpha", + "model": "gpt-3.5-turbo", "messages": [ { "role": "user", @@ -70,6 +82,63 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ } ' ``` + + + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:8000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } +]) + +print(response) + + +``` + + + + +```python +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import HumanMessage, SystemMessage + +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy + model = "gpt-3.5-turbo", + temperature=0.1 +) + +messages = [ + SystemMessage( + content="You are a helpful assistant that im using to make a test request to." + ), + HumanMessage( + content="test from litellm. tell me why it's amazing in 1 sentence" + ), +] +response = chat(messages) + +print(response) +``` + + + + ## Save Model-specific params (API Base, API Keys, Temperature, Headers etc.) You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc. From e615f2670a16baca1aac190420c19ee169687afd Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:02:35 -0800 Subject: [PATCH 073/125] (docs) proxy + configs --- docs/my-website/docs/proxy/configs.md | 35 ++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index ed5424b1a..305280bdd 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -68,6 +68,10 @@ Calling a model group +Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml. 
+ +If multiple with `model_name=gpt-3.5-turbo` does [Load Balancing](https://docs.litellm.ai/docs/proxy/load_balancing) + ```shell curl --location 'http://0.0.0.0:8000/chat/completions' \ --header 'Content-Type: application/json' \ @@ -83,6 +87,26 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \ ' ``` + + + +Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml + +```shell +curl --location 'http://0.0.0.0:8000/chat/completions' \ +--header 'Content-Type: application/json' \ +--data ' { + "model": "bedrock-claude-v1", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + } +' +``` + ```python @@ -92,7 +116,7 @@ client = openai.OpenAI( base_url="http://0.0.0.0:8000" ) -# request sent to model set on litellm proxy, `litellm --model` +# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml. response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ { "role": "user", @@ -102,6 +126,15 @@ response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ print(response) +# Sends this request to model where `model_name=bedrock-claude-v1` on config.yaml +response = client.chat.completions.create(model="bedrock-claude-v1", messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } +]) + +print(response) ``` From 88c95ca259f8677b1eb1e2a356e81bf50e87add1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:16:33 -0800 Subject: [PATCH 074/125] fix(_redis.py): support additional params for redis --- litellm/_redis.py | 85 +++++++++++++++++++++++++++++++++++ litellm/caching.py | 16 ++++++- litellm/proxy/proxy_server.py | 7 +-- litellm/router.py | 22 +++++---- litellm/utils.py | 34 +++++++------- 5 files changed, 135 insertions(+), 29 deletions(-) create mode 100644 litellm/_redis.py diff --git a/litellm/_redis.py b/litellm/_redis.py new file mode 100644 index 000000000..82e0ab0ec --- /dev/null +++ b/litellm/_redis.py @@ -0,0 +1,85 @@ +# +-----------------------------------------------+ +# | | +# | Give Feedback / Get Help | +# | https://github.com/BerriAI/litellm/issues/new | +# | | +# +-----------------------------------------------+ +# +# Thank you users! We ❤️ you! 
- Krrish & Ishaan + +# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation +import os +import inspect +import redis, litellm + +def _get_redis_kwargs(): + arg_spec = inspect.getfullargspec(redis.Redis) + + # Only allow primitive arguments + exclude_args = { + "self", + "connection_pool", + "retry", + } + + + include_args = [ + "url" + ] + + available_args = [ + x for x in arg_spec.args if x not in exclude_args + ] + include_args + + return available_args + +def _get_redis_env_kwarg_mapping(): + PREFIX = "REDIS_" + + return { + f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs() + } + + +def _redis_kwargs_from_environment(): + mapping = _get_redis_env_kwarg_mapping() + + return_dict = {} + for k, v in mapping.items(): + value = litellm.get_secret(k, default_value=None) # check os.environ/key vault + if value is not None: + return_dict[v] = value + return return_dict + + +def get_redis_url_from_environment(): + if "REDIS_URL" in os.environ: + return os.environ["REDIS_URL"] + + if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ: + raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.") + + if "REDIS_PASSWORD" in os.environ: + redis_password = f":{os.environ['REDIS_PASSWORD']}@" + else: + redis_password = "" + + return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}" + +def get_redis_client(**env_overrides): + redis_kwargs = { + **_redis_kwargs_from_environment(), + **env_overrides, + } + + if "url" in redis_kwargs and redis_kwargs['url'] is not None: + redis_kwargs.pop("host", None) + redis_kwargs.pop("port", None) + redis_kwargs.pop("db", None) + redis_kwargs.pop("password", None) + + return redis.Redis.from_url(**redis_kwargs) + elif "host" not in redis_kwargs or redis_kwargs['host'] is None: + raise ValueError("Either 'host' or 'url' must be specified for redis.") + + return redis.Redis(**redis_kwargs) \ No newline at end of file diff --git a/litellm/caching.py b/litellm/caching.py index d9b94b958..1b6963cc6 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -69,10 +69,22 @@ class InMemoryCache(BaseCache): class RedisCache(BaseCache): - def __init__(self, host, port, password, **kwargs): + def __init__(self, host=None, port=None, password=None, **kwargs): import redis # if users don't provider one, use the default litellm cache - self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs) + from ._redis import get_redis_client + + redis_kwargs = {} + if host is not None: + redis_kwargs["host"] = host + if port is not None: + redis_kwargs["port"] = port + if password is not None: + redis_kwargs["password"] = password + + redis_kwargs.update(kwargs) + + self.redis_client = get_redis_client(**redis_kwargs) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ecd4ab8d6..3f94f90b9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -477,9 +477,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}\nSetting Cache on Proxy") from litellm.caching import Cache cache_type = value["type"] - cache_host = litellm.get_secret("REDIS_HOST") - cache_port = litellm.get_secret("REDIS_PORT") - cache_password = litellm.get_secret("REDIS_PASSWORD") + cache_host = litellm.get_secret("REDIS_HOST", None) + cache_port = 
litellm.get_secret("REDIS_PORT", None) + cache_password = litellm.get_secret("REDIS_PASSWORD", None) # Assuming cache_type, cache_host, cache_port, and cache_password are strings print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") @@ -488,6 +488,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}") print() + ## to pass a complete url, just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables litellm.cache = Cache( type=cache_type, host=cache_host, diff --git a/litellm/router.py b/litellm/router.py index 75fd5afd9..478b5dd23 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -60,10 +60,14 @@ class Router: def __init__(self, model_list: Optional[list] = None, + ## CACHING ## + redis_url: Optional[str] = None, redis_host: Optional[str] = None, redis_port: Optional[int] = None, redis_password: Optional[str] = None, cache_responses: bool = False, + cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) + ## RELIABILITY ## num_retries: int = 0, timeout: Optional[float] = None, default_litellm_params = {}, # default params for Router.chat.completion.create @@ -107,21 +111,21 @@ class Router: if self.routing_strategy == "least-busy": self._start_health_check_thread() ### CACHING ### + cache_type = "local" # default to an in-memory cache redis_cache = None - if redis_host is not None and redis_port is not None and redis_password is not None: + cache_config = {} + if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): + cache_type = "redis" cache_config = { - 'type': 'redis', + 'url': redis_url, 'host': redis_host, 'port': redis_port, - 'password': redis_password - } - redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password) - else: # use an in-memory cache - cache_config = { - "type": "local" + 'password': redis_password, + **cache_kwargs } + redis_cache = RedisCache(**cache_config) if cache_responses: - litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests + litellm.cache = litellm.Cache(type=cache_type, **cache_config) self.cache_responses = cache_responses self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. ## USAGE TRACKING ## diff --git a/litellm/utils.py b/litellm/utils.py index 84cd20729..c89e690d7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4763,23 +4763,27 @@ def litellm_telemetry(data): ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str): +def get_secret(secret_name: str, default_value: Optional[str]=None): if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") - if litellm.secret_manager_client is not None: - # TODO: check which secret manager is being used - # currently only supports Infisical - try: - client = litellm.secret_manager_client - if type(client).__module__ + '.' 
+ type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient - secret = retrieved_secret = client.get_secret(secret_name).value - else: # assume the default is infisicial client - secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ - secret = os.environ.get(secret_name) - return secret - else: - return os.environ.get(secret_name) + try: + if litellm.secret_manager_client is not None: + try: + client = litellm.secret_manager_client + if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + secret = retrieved_secret = client.get_secret(secret_name).value + else: # assume the default is infisicial client + secret = client.get_secret(secret_name).secret_value + except: # check if it's in os.environ + secret = os.environ.get(secret_name) + return secret + else: + return os.environ.get(secret_name) + except Exception as e: + if default_value is not None: + return default_value + else: + raise e ######## Streaming Class ############################ From 951bcfc043131288e8f1f410c0311feb92e39fb8 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:07:18 -0800 Subject: [PATCH 075/125] (fix) router init: raise error Azure API Base not set --- litellm/router.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index 478b5dd23..70ac4efe6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -880,6 +880,8 @@ class Router: self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") if "azure" in model_name: + if api_base is None: + raise ValueError("api_base is required for Azure OpenAI. Set it on your config") self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}") if api_version is None: api_version = "2023-07-01-preview" From 4f9f53f7dc87ade8a1948b439530f6f64f6c2f9b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:20:12 -0800 Subject: [PATCH 076/125] (docs) proxy: config example with langchain --- docs/my-website/docs/proxy/configs.md | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 305280bdd..490da2294 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -139,7 +139,7 @@ print(response) ``` - + ```python from langchain.chat_models import ChatOpenAI @@ -150,12 +150,6 @@ from langchain.prompts.chat import ( ) from langchain.schema import HumanMessage, SystemMessage -chat = ChatOpenAI( - openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy - model = "gpt-3.5-turbo", - temperature=0.1 -) - messages = [ SystemMessage( content="You are a helpful assistant that im using to make a test request to." @@ -164,8 +158,25 @@ messages = [ content="test from litellm. tell me why it's amazing in 1 sentence" ), ] -response = chat(messages) +# Sends request to model where `model_name=gpt-3.5-turbo` on config.yaml. +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy + model = "gpt-3.5-turbo", + temperature=0.1 +) + +response = chat(messages) +print(response) + +# Sends request to model where `model_name=bedrock-claude-v1` on config.yaml. 
+claude_chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", # set openai base to the proxy + model = "bedrock-claude-v1", + temperature=0.1 +) + +response = claude_chat(messages) print(response) ``` From d5f67a0a257cd08232960f6eac1f75c0cab085cc Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:20:29 -0800 Subject: [PATCH 077/125] (docs) proxy + langchain --- litellm/proxy/tests/test_langchain_request.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/litellm/proxy/tests/test_langchain_request.py b/litellm/proxy/tests/test_langchain_request.py index af6691f3c..2306ffaf7 100644 --- a/litellm/proxy/tests/test_langchain_request.py +++ b/litellm/proxy/tests/test_langchain_request.py @@ -24,5 +24,16 @@ response = chat(messages) print(response) +claude_chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:8000", + model = "claude-v1", + temperature=0.1 +) + +response = claude_chat(messages) + +print(response) + + From ef7795add6680591a60b5cb71c7134c3571a4657 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:26:10 -0800 Subject: [PATCH 078/125] fix(utils.py): set text if empty string --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index c89e690d7..6518d8852 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -414,7 +414,7 @@ class TextChoices(OpenAIObject): else: self.finish_reason = "stop" self.index = index - if text: + if text is not None: self.text = text else: self.text = None From 55b34f969c32ca147a884f31a1a7c32b4d1753ab Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:26:36 -0800 Subject: [PATCH 079/125] =?UTF-8?q?bump:=20version=201.10.6=20=E2=86=92=20?= =?UTF-8?q?1.10.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 029e925a7..ea9117e7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.6" +version = "1.10.7" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.6" +version = "1.10.7" version_files = [ "pyproject.toml:^version" ] From 2a02fcbffbd311f5eda0789844324f87d3b839b6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:38:18 -0800 Subject: [PATCH 080/125] fix(utils.py): map cohere finish reasons --- litellm/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/litellm/utils.py b/litellm/utils.py index 6518d8852..e8e2b57b3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -123,6 +123,15 @@ def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences - # anthropic mapping if finish_reason == "stop_sequence": return "stop" + # cohere mapping - https://docs.cohere.com/reference/generate + elif finish_reason == "COMPLETE": + return "stop" + elif finish_reason == "MAX_TOKENS": + return "length" + elif finish_reason == "ERROR_TOXIC": + return "content_filter" + elif finish_reason == "ERROR": # openai currently doesn't support an 'error' finish reason + return "stop" return finish_reason class FunctionCall(OpenAIObject): From a9b50a12c5c56ff4a83dbb2ea01799eb6a38abf8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:38:42 -0800 Subject: 
[PATCH 081/125] =?UTF-8?q?bump:=20version=201.10.7=20=E2=86=92=20?= =?UTF-8?q?1.10.8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea9117e7b..3490408c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.7" +version = "1.10.8" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.7" +version = "1.10.8" version_files = [ "pyproject.toml:^version" ] From 397eefabe114c84f9b82f465168597d7fb8d017a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:45:52 -0800 Subject: [PATCH 082/125] test: remove local test --- litellm/proxy/tests/test_langchain_request.py | 59 ++++++++++--------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/litellm/proxy/tests/test_langchain_request.py b/litellm/proxy/tests/test_langchain_request.py index 2306ffaf7..1841b4968 100644 --- a/litellm/proxy/tests/test_langchain_request.py +++ b/litellm/proxy/tests/test_langchain_request.py @@ -1,38 +1,39 @@ -from langchain.chat_models import ChatOpenAI -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.schema import HumanMessage, SystemMessage +## LOCAL TEST +# from langchain.chat_models import ChatOpenAI +# from langchain.prompts.chat import ( +# ChatPromptTemplate, +# HumanMessagePromptTemplate, +# SystemMessagePromptTemplate, +# ) +# from langchain.schema import HumanMessage, SystemMessage -chat = ChatOpenAI( - openai_api_base="http://0.0.0.0:8000", - model = "gpt-3.5-turbo", - temperature=0.1 -) +# chat = ChatOpenAI( +# openai_api_base="http://0.0.0.0:8000", +# model = "gpt-3.5-turbo", +# temperature=0.1 +# ) -messages = [ - SystemMessage( - content="You are a helpful assistant that im using to make a test request to." - ), - HumanMessage( - content="test from litellm. tell me why it's amazing in 1 sentence" - ), -] -response = chat(messages) +# messages = [ +# SystemMessage( +# content="You are a helpful assistant that im using to make a test request to." +# ), +# HumanMessage( +# content="test from litellm. 
tell me why it's amazing in 1 sentence" +# ), +# ] +# response = chat(messages) -print(response) +# print(response) -claude_chat = ChatOpenAI( - openai_api_base="http://0.0.0.0:8000", - model = "claude-v1", - temperature=0.1 -) +# claude_chat = ChatOpenAI( +# openai_api_base="http://0.0.0.0:8000", +# model = "claude-v1", +# temperature=0.1 +# ) -response = claude_chat(messages) +# response = claude_chat(messages) -print(response) +# print(response) From d9f083b5f8b71cc0ca35198e5cc40a1d6107696e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:46:09 -0800 Subject: [PATCH 083/125] (fix) router: remove misleading print statement --- litellm/router.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 70ac4efe6..0f8a6b948 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -877,8 +877,6 @@ class Router: max_retries_env_name = api_version.replace("os.environ/", "") max_retries = litellm.get_secret(max_retries_env_name) - - self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") if "azure" in model_name: if api_base is None: raise ValueError("api_base is required for Azure OpenAI. Set it on your config") From 3f84ab04c4d283c457d9813c9e7a4fb2c78cf754 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:53:20 -0800 Subject: [PATCH 084/125] (fix) router: Azure Client Init --- litellm/router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 0f8a6b948..881d89deb 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -936,7 +936,7 @@ class Router: # streaming clients should have diff timeouts model["stream_async_client"] = openai.AsyncAzureOpenAI( api_key=api_key, - base_url=api_base, + azure_endpoint=api_base, api_version=api_version, timeout=stream_timeout, max_retries=max_retries @@ -944,7 +944,7 @@ class Router: model["stream_client"] = openai.AzureOpenAI( api_key=api_key, - base_url=api_base, + azure_endpoint=api_base, api_version=api_version, timeout=stream_timeout, max_retries=max_retries From 5829227d863f4c63bc4cd76f011fe1500706426f Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 12:53:52 -0800 Subject: [PATCH 085/125] (test) router streaming + azure --- litellm/tests/test_router.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index ae9a06c41..4c806291a 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -74,7 +74,8 @@ def test_exception_raising(): def test_reading_key_from_model_list(): - # this tests if the router raises an exception when invalid params are set + # [PROD TEST CASE] + # this tests if the router can read key from model list and make completion call, and completion + stream call. This is 90% of the router use case # DO NOT REMOVE THIS TEST. It's an IMP ONE. 
Speak to Ishaan, if you are tring to remove this litellm.set_verbose=False import openai @@ -112,6 +113,23 @@ def test_reading_key_from_model_list(): } ] ) + print("\n response", response) + + print("\n Testing streaming response") + response = router.completion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello this request will fail" + } + ], + stream=True + ) + for chunk in response: + if chunk is not None: + print(chunk) + print("\n Passed Streaming") os.environ["AZURE_API_KEY"] = old_api_key router.reset() except Exception as e: From 58ab0a3f03a7e48ac02375f292d42bd1d2073655 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:54:20 -0800 Subject: [PATCH 086/125] fix(router.py): fix cache init --- litellm/router.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 881d89deb..203cc419e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -116,13 +116,21 @@ class Router: cache_config = {} if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): cache_type = "redis" - cache_config = { - 'url': redis_url, - 'host': redis_host, - 'port': redis_port, - 'password': redis_password, - **cache_kwargs - } + + if redis_url is not None: + cache_config['url'] = redis_url + + if redis_host is not None: + cache_config['host'] = redis_host + + if redis_port is not None: + cache_config['port'] = redis_port + + if redis_password is not None: + cache_config['password'] = redis_password + + # Add additional key-value pairs from cache_kwargs + cache_config.update(cache_kwargs) redis_cache = RedisCache(**cache_config) if cache_responses: litellm.cache = litellm.Cache(type=cache_type, **cache_config) From e579918dd93d3c33a17077b7cd374b0338eebd35 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:22:27 -0800 Subject: [PATCH 087/125] (test) Router: Test Azure acompletion, stream --- litellm/tests/test_router.py | 83 ++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 4c806291a..5f6ef2e21 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -139,6 +139,89 @@ def test_reading_key_from_model_list(): # test_reading_key_from_model_list() + +def test_router_azure_acompletion(): + # [PROD TEST CASE] + # This is 90% of the router use case, makes an acompletion call, acompletion + stream call and verifies it got a response + # DO NOT REMOVE THIS TEST. It's an IMP ONE. 
Speak to Ishaan, if you are tring to remove this + litellm.set_verbose=False + import openai + try: + print("Router Test Azure - Acompletion, Acompletion with stream") + + # remove api key from env to repro how proxy passes key to router + old_api_key = os.environ["AZURE_API_KEY"] + os.environ.pop("AZURE_API_KEY", None) + + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + } + ] + + router = Router(model_list=model_list, + routing_strategy="simple-shuffle", + set_verbose=True + ) # type: ignore + + async def test1(): + + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello this request will fail" + } + ] + ) + print("\n response", response) + asyncio.run(test1()) + + print("\n Testing streaming response") + async def test2(): + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello this request will fail" + } + ], + stream=True + ) + async for chunk in response: + if chunk is not None: + print(chunk) + asyncio.run(test2()) + print("\n Passed Streaming") + os.environ["AZURE_API_KEY"] = old_api_key + router.reset() + except Exception as e: + os.environ["AZURE_API_KEY"] = old_api_key + print(f"FAILED TEST") + pytest.fail(f"Got unexpected exception on router! 
- {e}") +test_router_azure_acompletion() + ### FUNCTION CALLING def test_function_calling(): From 4e3040b3577affc74cbe577abfab806a57890481 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:23:35 -0800 Subject: [PATCH 088/125] (chore) linting fix --- litellm/proxy/queue/celery_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/queue/celery_app.py b/litellm/proxy/queue/celery_app.py index 47c9c868f..b9006f13e 100644 --- a/litellm/proxy/queue/celery_app.py +++ b/litellm/proxy/queue/celery_app.py @@ -45,7 +45,7 @@ celery_app.conf.update( @celery_app.task(name='process_job', max_retries=3) def process_job(*args, **kwargs): try: - llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) + llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore response = llm_router.completion(*args, **kwargs) # type: ignore if isinstance(response, litellm.ModelResponse): response = response.model_dump_json() From 1463cc6023800cc2ebebc9cec92d394ed6c728d6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:28:07 -0800 Subject: [PATCH 089/125] (test) router Azure regular chat completion call --- litellm/tests/test_router.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 5f6ef2e21..82ae1646b 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -162,18 +162,16 @@ def test_router_azure_acompletion(): "api_version": os.getenv("AZURE_API_VERSION"), "api_base": os.getenv("AZURE_API_BASE") }, - "tpm": 240000, "rpm": 1800 }, { "model_name": "gpt-3.5-turbo", # openai model name "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": old_api_key, + "model": "azure/gpt-turbo", + "api_key": os.getenv("AZURE_FRANCE_API_KEY"), "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE") + "api_base": "https://openai-france-1234.openai.azure.com" }, - "tpm": 240000, "rpm": 1800 } ] From 63939c0a1112ddf8b0ddf21b7add90cae5b782e3 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:30:12 -0800 Subject: [PATCH 090/125] (fix) linting --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 203cc419e..f4b825eb2 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -124,7 +124,7 @@ class Router: cache_config['host'] = redis_host if redis_port is not None: - cache_config['port'] = redis_port + cache_config['port'] = int(redis_port) if redis_password is not None: cache_config['password'] = redis_password From d606a9cb4c23575aef529cbc4df454a0e7d954a8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 13:33:44 -0800 Subject: [PATCH 091/125] refactor(router.py): linting fixes --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index f4b825eb2..be6cbd917 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -124,7 +124,7 @@ class Router: cache_config['host'] = redis_host if redis_port is not None: - cache_config['port'] = int(redis_port) + cache_config['port'] = str(redis_port) # type: ignore if redis_password is not None: cache_config['password'] = redis_password From 0d1b42eda5f40fbdfafc496f5d50c754c66f9c87 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:35:03 -0800 Subject: [PATCH 092/125] 
(test) azure - test async + sync embedding --- litellm/tests/test_router.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 82ae1646b..0a14e166f 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -400,7 +400,10 @@ def test_aembedding_on_router(): # test_aembedding_on_router() -def test_azure_aembedding_on_router(): +def test_azure_embedding_on_router(): + """ + [PROD Use Case] - Makes an aembedding call + embedding call + """ litellm.set_verbose = True try: model_list = [ @@ -415,20 +418,28 @@ def test_azure_aembedding_on_router(): "rpm": 10000, }, ] + router = Router(model_list=model_list) async def embedding_call(): - router = Router(model_list=model_list) response = await router.aembedding( model="text-embedding-ada-002", input=["good morning from litellm"] ) print(response) - router.reset() asyncio.run(embedding_call()) + + print("\n Making sync Azure Embedding call\n") + + response = router.embedding( + model="text-embedding-ada-002", + input=["test 2 from litellm. async embedding"] + ) + print(response) + router.reset() except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") -# test_azure_aembedding_on_router() +test_azure_embedding_on_router() def test_bedrock_on_router(): From bc70a6fba8aac08a890d4b3fd8ea5927fa99a55e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 13:59:27 -0800 Subject: [PATCH 093/125] (test) router: add tests for azure completion, acompletion --- litellm/tests/test_router.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 0a14e166f..e64b55959 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -114,6 +114,9 @@ def test_reading_key_from_model_list(): ] ) print("\n response", response) + str_response = response.choices[0].message.content + print("\n str_response", str_response) + assert len(str_response) > 0 print("\n Testing streaming response") response = router.completion( @@ -126,9 +129,13 @@ def test_reading_key_from_model_list(): ], stream=True ) + completed_response = "" for chunk in response: if chunk is not None: print(chunk) + completed_response += chunk.choices[0].delta.content or "" + print("\n completed_response", completed_response) + assert len(completed_response) > 0 print("\n Passed Streaming") os.environ["AZURE_API_KEY"] = old_api_key router.reset() @@ -183,15 +190,18 @@ def test_router_azure_acompletion(): async def test1(): - response = await router.acompletion( + response: litellm.ModelResponse = await router.acompletion( model="gpt-3.5-turbo", messages=[ { "role": "user", - "content": "hello this request will fail" + "content": "hello this request will pass" } ] ) + str_response = response.choices[0].message.content + print("\n str_response", str_response) + assert len(str_response) > 0 print("\n response", response) asyncio.run(test1()) @@ -207,9 +217,13 @@ def test_router_azure_acompletion(): ], stream=True ) + completed_response = "" async for chunk in response: if chunk is not None: print(chunk) + completed_response += chunk.choices[0].delta.content or "" + print("\n completed_response", completed_response) + assert len(completed_response) > 0 asyncio.run(test2()) print("\n Passed Streaming") os.environ["AZURE_API_KEY"] = old_api_key From 3ff57493f437d6c1b07e1baaea4e6b2e996c51e0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 
2023 14:21:37 -0800 Subject: [PATCH 094/125] (test) router: openai async, sync, stream, no stream --- litellm/tests/test_router.py | 89 +++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index e64b55959..6572b9cd1 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -488,4 +488,91 @@ def test_bedrock_on_router(): except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") -# test_bedrock_on_router() \ No newline at end of file +# test_bedrock_on_router() + + +def test_openai_completion_on_router(): + # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream + # 4 LLM API calls made here. If it fails, add retries. Do not remove this test. + litellm.set_verbose = True + print("\n Testing OpenAI on router\n") + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + ] + router = Router(model_list=model_list) + + async def test(): + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ] + ) + print(response) + assert len(response.choices[0].message.content) > 0 + + print("\n streaming + acompletion test") + response = await router.acompletion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test", + } + ], + stream=True + ) + complete_response = "" + print(response) + async for chunk in response: + print(chunk) + complete_response += chunk.choices[0].delta.content or "" + print("\n complete response: ", complete_response) + assert len(complete_response) > 0 + + asyncio.run(test()) + print("\n Testing Sync completion calls \n") + response = router.completion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test2", + } + ] + ) + print(response) + assert len(response.choices[0].message.content) > 0 + + print("\n streaming + completion test") + response = router.completion( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "hello from litellm test3", + } + ], + stream=True + ) + complete_response = "" + print(response) + for chunk in response: + print(chunk) + complete_response += chunk.choices[0].delta.content or "" + print("\n complete response: ", complete_response) + assert len(complete_response) > 0 + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") +# test_openai_completion_on_router() \ No newline at end of file From c717ed4d052dfda88aa4db3cf2f1ae6f0e493452 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 14:28:23 -0800 Subject: [PATCH 095/125] (test) router: test async embedding + embedding --- litellm/tests/test_router.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 6572b9cd1..085fcc82a 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -398,20 +398,26 @@ def test_aembedding_on_router(): "rpm": 10000, }, ] - + router = Router(model_list=model_list) async def embedding_call(): - router = Router(model_list=model_list) response = await router.aembedding( model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"], ) print(response) - router.reset() asyncio.run(embedding_call()) + + 
print("\n Making sync Embedding call\n") + response = router.embedding( + model="text-embedding-ada-002", + input=["good morning from litellm 2"], + ) + print("sync embedding response: ", response) + router.reset() except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") -# test_aembedding_on_router() +test_aembedding_on_router() def test_azure_embedding_on_router(): From 4d7ff1b33b9991dcf38d821266290631d9bcd2dd Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 14:33:28 -0800 Subject: [PATCH 096/125] fix(proxy_server.py): don't override exceptions if they're of type httpexception --- litellm/proxy/proxy_server.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 3f94f90b9..c70d67bf1 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -279,10 +279,13 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap raise Exception(f"Invalid token") except Exception as e: print(f"An exception occurred - {traceback.format_exc()}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="invalid user key", - ) + if isinstance(e, HTTPException): + raise e + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid user key", + ) def prisma_setup(database_url: Optional[str]): global prisma_client From 68ca2a28d41701218b5111e7e3f05eae3e3dd370 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 15:08:00 -0800 Subject: [PATCH 097/125] docs: adds redis url to router + proxy docs --- docs/my-website/docs/caching/redis_cache.md | 10 +++++++++- docs/my-website/docs/proxy/caching.md | 14 ++++++++++++-- docs/my-website/docs/routing.md | 10 ++++++++++ litellm/tests/test_caching.py | 2 +- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/docs/my-website/docs/caching/redis_cache.md b/docs/my-website/docs/caching/redis_cache.md index f0fcc6952..521c4d00f 100644 --- a/docs/my-website/docs/caching/redis_cache.md +++ b/docs/my-website/docs/caching/redis_cache.md @@ -1,11 +1,14 @@ # Redis Cache + +[**See Code**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/caching.py#L71) + ### Pre-requisites Install redis ``` pip install redis ``` For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/ -### Usage +### Quick Start ```python import litellm from litellm import completion @@ -55,6 +58,11 @@ litellm.cache = cache # set litellm.cache to your cache ### Detecting Cached Responses For resposes that were returned as cache hit, the response includes a param `cache` = True +:::info + +Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+) +::: + Example response with cache hit ```python { diff --git a/docs/my-website/docs/proxy/caching.md b/docs/my-website/docs/proxy/caching.md index d052102db..56e4b4c1c 100644 --- a/docs/my-website/docs/proxy/caching.md +++ b/docs/my-website/docs/proxy/caching.md @@ -12,17 +12,27 @@ model_list: litellm_settings: set_verbose: True cache: # init cache - type: redis # tell litellm to use redis caching + type: redis # tell litellm to use redis caching (Also: `pip install redis`) ``` #### Step 2: Add Redis Credentials to .env -LiteLLM requires the following REDIS credentials in your env to enable caching +Set either 
`REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching. ```shell + REDIS_URL = "" # REDIS_URL='redis://username:password@hostname:port/database' + ## OR ## REDIS_HOST = "" # REDIS_HOST='redis-18841.c274.us-east-1-3.ec2.cloud.redislabs.com' REDIS_PORT = "" # REDIS_PORT='18841' REDIS_PASSWORD = "" # REDIS_PASSWORD='liteLlmIsAmazing' ``` + +**Additional kwargs** +You can pass in any additional redis.Redis arg, by storing the variable + value in your os environment, like this: +```shell +REDIS_ = "" +``` + +[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40) #### Step 3: Run proxy with config ```shell $ litellm --config /path/to/config.yaml diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index c0ca93b25..3f55ae28e 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -356,6 +356,16 @@ router = Router(model_list=model_list, print(response) ``` + +**Pass in Redis URL, additional kwargs** +```python +router = Router(model_list: Optional[list] = None, + ## CACHING ## + redis_url=os.getenv("REDIS_URL")", + cache_kwargs= {}, # additional kwargs to pass to RedisCache (see caching.py) + cache_responses=True) +``` + #### Default litellm.completion/embedding params You can also set default params for litellm completion/embedding calls. Here's how to do that: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 713f97b3e..a0980e9de 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -36,7 +36,7 @@ def test_caching_v2(): # test in memory cache print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") -# test_caching_v2() +test_caching_v2() From c4bda13820ee6c25ca9a02ad596e495b41f9981e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 15:32:15 -0800 Subject: [PATCH 098/125] (fix) sagemaker Llama-2 70b --- litellm/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index e8e2b57b3..f419ebcd3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2206,7 +2206,9 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "sagemaker": - if "llama-2" in model: + if "llama-2" in model.lower() or ( + "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist + ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge" # llama-2 models on sagemaker support the following args """ max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer. 
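For readers following patch 098: the broadened model-name check can be read as a small standalone predicate. Below is a minimal sketch (the function name `is_sagemaker_llama2` is illustrative, not part of the patch), assuming only that SageMaker/JumpStart endpoint names embed the base model name, e.g. `Llama-2-70b-chat-hf-48xlarge`:

```python
def is_sagemaker_llama2(model: str) -> bool:
    """Fuzzy check for whether llama-2 generation params should apply."""
    name = model.lower()
    # match "llama-2" directly, or any name containing both "llama" and "2"
    # (JumpStart can send endpoint names like "Llama-2-70b-chat-hf-48xlarge")
    return "llama-2" in name or ("llama" in name and "2" in name)

print(is_sagemaker_llama2("Llama-2-70b-chat-hf-48xlarge"))           # True
print(is_sagemaker_llama2("jumpstart-dft-hf-llm-mistral-7b-instruct"))  # False
```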
From 09c2c1610def9dc713d48a23d138436b6a8b7fb8 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 15:37:39 -0800 Subject: [PATCH 099/125] =?UTF-8?q?bump:=20version=201.10.8=20=E2=86=92=20?= =?UTF-8?q?1.10.9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3490408c2..e6a5e9dc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.8" +version = "1.10.9" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.8" +version = "1.10.9" version_files = [ "pyproject.toml:^version" ] From b4c78c7b9eb7657f2a793ccd681c527963cce9ca Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:04:43 -0800 Subject: [PATCH 100/125] fix(utils.py): support sagemaker llama2 custom endpoints --- litellm/llms/sagemaker.py | 30 ++++++++++------- litellm/main.py | 1 + litellm/tests/test_completion.py | 10 ++++-- litellm/utils.py | 57 +++++++++++++++----------------- 4 files changed, 53 insertions(+), 45 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 51be1f53d..1ee43ec2e 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -9,6 +9,7 @@ from litellm.utils import ModelResponse, get_secret, Usage import sys from copy import deepcopy import httpx +from .prompt_templates.factory import prompt_factory, custom_prompt class SagemakerError(Exception): def __init__(self, status_code, message): @@ -61,6 +62,7 @@ def completion( print_verbose: Callable, encoding, logging_obj, + custom_prompt_dict={}, optional_params=None, litellm_params=None, logger_fn=None, @@ -107,19 +109,23 @@ def completion( inference_params[k] = v model = model - prompt = "" - for message in messages: - if "role" in message: - if message["role"] == "user": - prompt += ( - f"{message['content']}" - ) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages + ) + else: + hf_model_name = model + if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model + if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model + hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model else: - prompt += ( - f"{message['content']}" - ) - else: - prompt += f"{message['content']}" + hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template + prompt = prompt_factory(model=hf_model_name, messages=messages) data = json.dumps({ "inputs": prompt, diff --git a/litellm/main.py b/litellm/main.py index 5c421a351..e0c7cf3b1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1166,6 +1166,7 @@ def completion( print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, + custom_prompt_dict=custom_prompt_dict, logger_fn=logger_fn, encoding=encoding, logging_obj=logging diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py 
index 15b547a85..3c4a9aa40 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1048,20 +1048,24 @@ def test_completion_sagemaker(): def test_completion_chat_sagemaker(): try: + messages = [{"role": "user", "content": "Hey, how's it going?"}] print("testing sagemaker") litellm.set_verbose=True response = completion( model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f", messages=messages, + max_tokens=100, stream=True, ) # Add any assertions here to check the response - print(response) + complete_response = "" for chunk in response: - print(chunk) + complete_response += chunk.choices[0].delta.content or "" + print(f"complete_response: {complete_response}") + assert len(complete_response) > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_chat_sagemaker() +test_completion_chat_sagemaker() def test_completion_bedrock_titan(): try: diff --git a/litellm/utils.py b/litellm/utils.py index f419ebcd3..814735f88 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2206,32 +2206,31 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "sagemaker": - if "llama-2" in model.lower() or ( - "llama" in model.lower() and "2" in model.lower() # some combination of llama and "2" should exist - ): # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge" - # llama-2 models on sagemaker support the following args - """ - max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer. - temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float. - top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1. - return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False. 
- """ - ## check if unsupported param passed in - supported_params = ["temperature", "max_tokens", "stream"] - _check_valid_arg(supported_params=supported_params) - - if max_tokens is not None: - optional_params["max_new_tokens"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stream: - optional_params["stream"] = stream - else: - ## check if unsupported param passed in - supported_params = [] - _check_valid_arg(supported_params=supported_params) + ## check if unsupported param passed in + supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] + _check_valid_arg(supported_params=supported_params) + # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None + if temperature is not None: + if temperature == 0.0 or temperature == 0: + # hugging face exception raised when temp==0 + # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive + temperature = 0.01 + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["top_p"] = top_p + if n is not None: + optional_params["best_of"] = n + optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints + if stream is not None: + optional_params["stream"] = stream + if stop is not None: + optional_params["stop"] = stop + if max_tokens is not None: + # HF TGI raises the following exception when max_new_tokens==0 + # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive + if max_tokens == 0: + max_tokens = 1 + optional_params["max_new_tokens"] = max_tokens elif custom_llm_provider == "bedrock": if "ai21" in model: supported_params = ["max_tokens", "temperature", "top_p", "stream"] @@ -5270,11 +5269,9 @@ class CustomStreamWrapper: else: model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True - chunk_size = 30 - new_chunk = self.completion_stream[:chunk_size] + new_chunk = self.completion_stream completion_obj["content"] = new_chunk - self.completion_stream = self.completion_stream[chunk_size:] - time.sleep(0.05) + self.completion_stream = self.completion_stream[len(self.completion_stream):] elif self.custom_llm_provider == "petals": if len(self.completion_stream)==0: if self.sent_last_chunk: From d2dab362dfe357f1c2e84a3c5ebd97f3d5c18c0c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 16:08:12 -0800 Subject: [PATCH 101/125] (fix) proxy debugging display Init API key --- litellm/proxy/proxy_server.py | 4 ++-- litellm/router.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c70d67bf1..4cdb4d082 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -609,6 +609,8 @@ def initialize( generate_feedback_box() user_model = model user_debug = debug + if debug==True: # this needs to be first, so users can see Router init debugg + litellm.set_verbose = True dynamic_config = {"general": {}, user_model: {}} if config: llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config) @@ -646,8 +648,6 @@ def initialize( if max_budget: # litellm-specific param litellm.max_budget = max_budget dynamic_config["general"]["max_budget"] = max_budget - if debug==True: # litellm-specific param - litellm.set_verbose = True if use_queue: 
celery_setup(use_queue=use_queue) if experimental: diff --git a/litellm/router.py b/litellm/router.py index be6cbd917..4f344e18f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -959,7 +959,7 @@ class Router: ) else: - self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") + self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}, {api_key}") model["async_client"] = openai.AsyncOpenAI( api_key=api_key, base_url=api_base, From 01fc7f1931c2297d79d9f12654e58623a196c627 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:18:13 -0800 Subject: [PATCH 102/125] fix(sagemaker.py): add support for amazon neuron llama models --- litellm/llms/sagemaker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 1ee43ec2e..a575bf9d1 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -120,8 +120,8 @@ def completion( ) else: hf_model_name = model - if "jumpstart-dft-meta-textgeneration-llama" in model: # llama2 model - if model.endswith("-f") or "-f-" in model: # sagemaker default for a chat model + if "jumpstart-dft-meta-textgeneration-llama" in model or "meta-textgenerationneuron-llama-2-7b" in model: # llama2 model + if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model else: hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template From 3c60682eb4096c5486b639c7a413ba681894d439 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:19:28 -0800 Subject: [PATCH 103/125] fix(sagemaker.py): accept all amazon neuron llama2 models --- litellm/llms/sagemaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index a575bf9d1..3e23ae415 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -120,7 +120,7 @@ def completion( ) else: hf_model_name = model - if "jumpstart-dft-meta-textgeneration-llama" in model or "meta-textgenerationneuron-llama-2-7b" in model: # llama2 model + if "jumpstart-dft-meta-textgeneration-llama" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model else: From a38504ff1b85fe2acf5ce27c72c450cdede12bf2 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:23:03 -0800 Subject: [PATCH 104/125] fix(sagemaker.py): fix meta llama model name for sagemaker custom deployment --- litellm/llms/sagemaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 3e23ae415..3ddfb4c60 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -120,7 +120,7 @@ def completion( ) else: hf_model_name = model - if "jumpstart-dft-meta-textgeneration-llama" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model + if "meta-textgeneration-llama-2" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model else: From 
54d8a9df3f7283d02e1ad132d5f99fa519871c75 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:31:59 -0800 Subject: [PATCH 105/125] fix(sagemaker.py): enable passing hf model name for prompt template --- litellm/llms/sagemaker.py | 8 ++------ litellm/main.py | 5 ++++- litellm/tests/test_completion.py | 2 ++ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 3ddfb4c60..e1c9ccdc8 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -63,6 +63,7 @@ def completion( encoding, logging_obj, custom_prompt_dict={}, + hf_model_name=None, optional_params=None, litellm_params=None, logger_fn=None, @@ -119,12 +120,7 @@ def completion( messages=messages ) else: - hf_model_name = model - if "meta-textgeneration-llama-2" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model - if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model - hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model - else: - hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template + hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) prompt = prompt_factory(model=hf_model_name, messages=messages) data = json.dumps({ diff --git a/litellm/main.py b/litellm/main.py index e0c7cf3b1..f265d4653 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -341,11 +341,13 @@ def completion( final_prompt_value = kwargs.get("final_prompt_value", None) bos_token = kwargs.get("bos_token", None) eos_token = kwargs.get("eos_token", None) + hf_model_name = kwargs.get("hf_model_name", None) + ### ASYNC CALLS ### acompletion = kwargs.get("acompletion", False) client = kwargs.get("client", None) ######## end of unpacking kwargs ########### openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] - litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"] + litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name"] default_params = openai_params + litellm_params non_default_params = {k: v for k,v in kwargs.items() if k not in 
default_params} # model-specific params - pass them straight to the model/provider if mock_response: @@ -1167,6 +1169,7 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, custom_prompt_dict=custom_prompt_dict, + hf_model_name=hf_model_name, logger_fn=logger_fn, encoding=encoding, logging_obj=logging diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 3c4a9aa40..d0cda9335 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1039,6 +1039,7 @@ def test_completion_sagemaker(): messages=messages, temperature=0.2, max_tokens=80, + hf_model_name="meta-llama/Llama-2-7b", ) # Add any assertions here to check the response print(response) @@ -1056,6 +1057,7 @@ def test_completion_chat_sagemaker(): messages=messages, max_tokens=100, stream=True, + hf_model_name="meta-llama/Llama-2-7b-chat-hf", ) # Add any assertions here to check the response complete_response = "" From 88845dddb1c25ebeda244ae2526fd9ba8df21944 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:42:19 -0800 Subject: [PATCH 106/125] fix(sagemaker.py): bring back llama2 templating for sagemaker --- litellm/llms/sagemaker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index e1c9ccdc8..ca71461cf 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -120,6 +120,12 @@ def completion( messages=messages ) else: + if hf_model_name is None: + if "llama2" in model.lower(): # llama2 model + if "chat" in model.lower(): + hf_model_name = "meta-llama/Llama-2-7b-chat-hf" + else: + hf_model_name = "meta-llama/Llama-2-7b" hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) prompt = prompt_factory(model=hf_model_name, messages=messages) From ff949490de856e30b3ec502c09065d2f25776e12 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:56:18 -0800 Subject: [PATCH 107/125] docs(input.md): add hf_model_name to docs --- docs/my-website/docs/completion/input.md | 21 ++++++++++++++++++++- litellm/llms/sagemaker.py | 15 +++++++++------ litellm/tests/test_completion.py | 3 ++- litellm/utils.py | 8 ++++++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index 7902275ab..047be5395 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -40,7 +40,7 @@ This list is constantly being updated. |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |VertexAI| ✅ | ✅ | | ✅ | | | | | | | |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | -|Sagemaker| ✅ | ✅ (only `jumpstart llama2`) | | ✅ | | | | | | | +|Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | @@ -185,6 +185,25 @@ def completion( - `metadata`: *dict (optional)* - Any additional data you want to be logged when the call is made (sent to logging integrations, eg. 
promptlayer and accessible via custom callback function) +**CUSTOM MODEL COST** +- `input_cost_per_token`: *float (optional)* - The cost per input token for the completion call + +- `output_cost_per_token`: *float (optional)* - The cost per output token for the completion call + +**CUSTOM PROMPT TEMPLATE** (See [prompt formatting for more info](./prompt_formatting.md#format-prompt-yourself)) +- `initial_prompt_value`: *string (optional)* - Initial string applied at the start of the input messages + +- `roles`: *dict (optional)* - Dictionary specifying how to format the prompt based on the role + message passed in via `messages`. + +- `final_prompt_value`: *string (optional)* - Final string applied at the end of the input messages + +- `bos_token`: *string (optional)* - Initial string applied at the start of a sequence + +- `eos_token`: *string (optional)* - Initial string applied at the end of a sequence + +- `hf_model_name`: *string (optional)* - [Sagemaker Only] The corresponding huggingface name of the model, used to pull the right chat template for the model. + + ## Provider-specific Params Providers might offer params not supported by OpenAI (e.g. top_k). You can pass those in 2 ways: - via completion(): We'll pass the non-openai param, straight to the provider as part of the request body. diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index ca71461cf..2482c5457 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -149,12 +149,15 @@ def completion( additional_args={"complete_input_dict": data, "request_str": request_str}, ) ## COMPLETION CALL - response = client.invoke_endpoint( - EndpointName=model, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", - ) + try: + response = client.invoke_endpoint( + EndpointName=model, + ContentType="application/json", + Body=data, + CustomAttributes="accept_eula=true", + ) + except Exception as e: + raise SagemakerError(status_code=500, message=f"{str(e)}") response = response["Body"].read().decode("utf8") ## LOGGING logging_obj.post_call( diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index d0cda9335..0f0bff1cc 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1053,10 +1053,11 @@ def test_completion_chat_sagemaker(): print("testing sagemaker") litellm.set_verbose=True response = completion( - model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f", + model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat", messages=messages, max_tokens=100, stream=True, + n=2, hf_model_name="meta-llama/Llama-2-7b-chat-hf", ) # Add any assertions here to check the response diff --git a/litellm/utils.py b/litellm/utils.py index 814735f88..9e93f6b64 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4065,6 +4065,14 @@ def exception_type( llm_provider="sagemaker", response=original_exception.response ) + elif "Input validation error: `best_of` must be > 0 and <= 2" in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, + llm_provider="sagemaker", + response=original_exception.response + ) elif custom_llm_provider == "vertex_ai": if "Vertex AI API has not been used in project" in error_str or "Unable to find your project" in error_str: exception_mapping_worked = True From bb6a1968b3eb5cbe9bf8b4a5a83e5f70f6f4e68c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 
2023 16:51:50 -0800 Subject: [PATCH 108/125] (fix) router - allow user to call 1 deployment --- litellm/router.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index 4f344e18f..b63f45f3d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1064,6 +1064,13 @@ class Router: healthy_deployments.remove(deployment) self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}") if len(healthy_deployments) == 0: + # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment + for deployment in self.model_list: + cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model")) + if cleaned_model == model: + # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2 + # return the first deployment where the `model` matches the specificed deployment name + return deployment raise ValueError("No models available") if litellm.model_alias_map and model in litellm.model_alias_map: model = litellm.model_alias_map[ From 703a575a5d1cffccca905977ac30a9c7a7cfb82a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 16:56:25 -0800 Subject: [PATCH 109/125] (test) call 1 deployment on router --- litellm/tests/test_router.py | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 085fcc82a..e0e232748 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -145,6 +145,75 @@ def test_reading_key_from_model_list(): pytest.fail(f"Got unexpected exception on router! - {e}") # test_reading_key_from_model_list() +def test_call_one_endpoint(): + # [PROD TEST CASE] + # user passes one deployment they want to call on the router, we call the specified one + # this test makes a completion calls azure/chatgpt-v-2, it should work + try: + print("Testing calling a specific deployment") + old_api_key = os.environ["AZURE_API_KEY"] + + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": old_api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + { + "model_name": "claude-v1", + "litellm_params": { + "model": "bedrock/anthropic.claude-instant-v1", + }, + "tpm": 100000, + "rpm": 10000, + } + ] + + router = Router(model_list=model_list, + routing_strategy="simple-shuffle", + set_verbose=True, + num_retries=1) # type: ignore + old_api_base = os.environ.pop("AZURE_API_BASE", None) + + + response = router.completion( + model="azure/chatgpt-v-2", + messages=[ + { + "role": "user", + "content": "hello this request will pass" + } + ], + ) + print("\n response", response) + + + response = router.completion( + model="bedrock/anthropic.claude-instant-v1", + messages=[ + { + "role": "user", + "content": "hello this request will pass" + } + ], + ) + + print("\n response", response) + + os.environ["AZURE_API_BASE"] = old_api_base + os.environ["AZURE_API_KEY"] = old_api_key + except Exception as e: + print(f"FAILED TEST") + pytest.fail(f"Got unexpected exception on router! 
- {e}") + +test_call_one_endpoint() + def test_router_azure_acompletion(): From 1addaecf48c6ace4705938bd2ca4e93dc8f4f42b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:58:12 -0800 Subject: [PATCH 110/125] docs(aws_sagemaker.md): add hf_model_name to sagemaker docs --- .../docs/providers/aws_sagemaker.md | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/my-website/docs/providers/aws_sagemaker.md b/docs/my-website/docs/providers/aws_sagemaker.md index 264d6d9f2..e9160139d 100644 --- a/docs/my-website/docs/providers/aws_sagemaker.md +++ b/docs/my-website/docs/providers/aws_sagemaker.md @@ -42,6 +42,27 @@ response = completion( ) ``` +### Specifying HF Model Name +To apply the correct prompt template for your sagemaker deployment, pass in it's hf model name as well. + +```python +import os +from litellm import completion + +os.environ["AWS_ACCESS_KEY_ID"] = "" +os.environ["AWS_SECRET_ACCESS_KEY"] = "" +os.environ["AWS_REGION_NAME"] = "" + +response = completion( + model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", + messages=messages, + temperature=0.2, + max_tokens=80, + hf_model_name="meta-llama/Llama-2-7b", + ) +``` + + ### Usage - Streaming Sagemaker currently does not support streaming - LiteLLM fakes streaming by returning chunks of the response string From e4fae5a3e8946ef7d60dc24f814839a6a7472180 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 16:59:37 -0800 Subject: [PATCH 111/125] docs(aws_sagemaker.md): support for all huggingface/jumpstart modelsn --- docs/my-website/docs/providers/aws_sagemaker.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/my-website/docs/providers/aws_sagemaker.md b/docs/my-website/docs/providers/aws_sagemaker.md index e9160139d..606268ad1 100644 --- a/docs/my-website/docs/providers/aws_sagemaker.md +++ b/docs/my-website/docs/providers/aws_sagemaker.md @@ -1,5 +1,5 @@ # AWS Sagemaker -LiteLLM supports Llama2 on Sagemaker +LiteLLM supports All Sagemaker Huggingface Jumpstart Models ### API KEYS ```python From a532cf14aefa9a4dc77457d832b137f96364c6b6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:18:46 -0800 Subject: [PATCH 112/125] (feat) router - track original deployment names --- litellm/router.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index b63f45f3d..775d1c90b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -78,6 +78,7 @@ class Router: routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None: self.set_verbose = set_verbose + self.deployment_names = [] # names of models under litellm_params. ex. 
azure/chatgpt-v-2 if model_list: self.set_model_list(model_list) self.healthy_deployments: List = self.model_list @@ -990,6 +991,7 @@ class Router: ) ############ End of initializing Clients for OpenAI/Azure ################### + self.deployment_names.append(model["litellm_params"]["model"]) model_id = "" for key in model["litellm_params"]: if key != "api_key": From 3af4f7fb0f16337fb77412886f8206ab4ace1c44 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:19:39 -0800 Subject: [PATCH 113/125] (fix) proxy: /chat/cmp - check 1 deployment --- litellm/proxy/proxy_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4cdb4d082..088181028 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -863,7 +863,9 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] if llm_router is not None and data["model"] in router_model_names: # model in router model list response = await llm_router.acompletion(**data) - else: + elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router + response = await llm_router.acompletion(**data) + else: # router is not set response = await litellm.acompletion(**data) if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses return StreamingResponse(async_data_generator(response), media_type='text/event-stream') From 4d5313343b3d02382bf0628faf991ed7ceb2253b Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:22:07 -0800 Subject: [PATCH 114/125] (feat) proxy /embedding check 1 deploy call --- litellm/proxy/proxy_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 088181028..80f19986b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -925,6 +925,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] if llm_router is not None and data["model"] in router_model_names: # model in router model list response = await llm_router.aembedding(**data) + elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router + response = await llm_router.aembedding(**data) else: response = await litellm.aembedding(**data) background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL From e788a34da467a8504ed116d0fabf1c7d59d4fa04 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:26:03 -0800 Subject: [PATCH 115/125] (chore) linting fix --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 775d1c90b..941ad1318 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -78,7 +78,7 @@ class Router: routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None: self.set_verbose = set_verbose - self.deployment_names = [] # names of models under litellm_params. ex. azure/chatgpt-v-2 + self.deployment_names: List = [] # names of models under litellm_params. ex. 
azure/chatgpt-v-2 if model_list: self.set_model_list(model_list) self.healthy_deployments: List = self.model_list From 1fa9ddd739fdc56b36dfe675272037093688e976 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:29:09 -0800 Subject: [PATCH 116/125] (chore) linting fix --- litellm/tests/test_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index e0e232748..61ec8aba9 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -259,7 +259,7 @@ def test_router_azure_acompletion(): async def test1(): - response: litellm.ModelResponse = await router.acompletion( + response = await router.acompletion( model="gpt-3.5-turbo", messages=[ { From 0eccc1b1f8c685278db2f1bedb2491ddc8039280 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:35:35 -0800 Subject: [PATCH 117/125] (test) router: call 1 deployment --- litellm/tests/test_router.py | 68 ++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 61ec8aba9..ca84f144c 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -172,7 +172,17 @@ def test_call_one_endpoint(): }, "tpm": 100000, "rpm": 10000, - } + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "azure/azure-embedding-model", + "api_key":os.environ['AZURE_API_KEY'], + "api_base": os.environ['AZURE_API_BASE'] + }, + "tpm": 100000, + "rpm": 10000, + }, ] router = Router(model_list=model_list, @@ -181,38 +191,50 @@ def test_call_one_endpoint(): num_retries=1) # type: ignore old_api_base = os.environ.pop("AZURE_API_BASE", None) + async def call_azure_completion(): + response = await router.acompletion( + model="azure/chatgpt-v-2", + messages=[ + { + "role": "user", + "content": "hello this request will pass" + } + ], + ) + print("\n response", response) - response = router.completion( - model="azure/chatgpt-v-2", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ], - ) - print("\n response", response) + async def call_bedrock_claude(): + response = await router.acompletion( + model="bedrock/anthropic.claude-instant-v1", + messages=[ + { + "role": "user", + "content": "hello this request will pass" + } + ], + ) + print("\n response", response) + + async def call_azure_embedding(): + response = await router.aembedding( + model="azure/azure-embedding-model", + input = ["good morning from litellm"] + ) - response = router.completion( - model="bedrock/anthropic.claude-instant-v1", - messages=[ - { - "role": "user", - "content": "hello this request will pass" - } - ], - ) - - print("\n response", response) + print("\n response", response) + asyncio.run(call_azure_completion()) + asyncio.run(call_bedrock_claude()) + asyncio.run(call_azure_embedding()) + os.environ["AZURE_API_BASE"] = old_api_base os.environ["AZURE_API_KEY"] = old_api_key except Exception as e: print(f"FAILED TEST") pytest.fail(f"Got unexpected exception on router! 
- {e}") -test_call_one_endpoint() +# test_call_one_endpoint() From 648d41c96f3178405ac1252a9c195737f3025794 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 17:47:37 -0800 Subject: [PATCH 118/125] fix(sagemaker.py): prompt templating fixes --- litellm/llms/sagemaker.py | 8 ++++---- litellm/tests/test_completion.py | 28 ++++++++++++++++++++-------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 2482c5457..cb5b56bdd 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -121,10 +121,10 @@ def completion( ) else: if hf_model_name is None: - if "llama2" in model.lower(): # llama2 model - if "chat" in model.lower(): + if "llama-2" in model.lower(): # llama-2 model + if "chat" in model.lower(): # apply llama2 chat template hf_model_name = "meta-llama/Llama-2-7b-chat-hf" - else: + else: # apply regular llama2 template hf_model_name = "meta-llama/Llama-2-7b" hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) prompt = prompt_factory(model=hf_model_name, messages=messages) @@ -146,7 +146,7 @@ def completion( logging_obj.pre_call( input=prompt, api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, + additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name}, ) ## COMPLETION CALL try: diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 0f0bff1cc..69aff761c 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1035,30 +1035,27 @@ def test_completion_sagemaker(): print("testing sagemaker") litellm.set_verbose=True response = completion( - model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b", + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, temperature=0.2, max_tokens=80, - hf_model_name="meta-llama/Llama-2-7b", ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_sagemaker() +test_completion_sagemaker() def test_completion_chat_sagemaker(): try: messages = [{"role": "user", "content": "Hey, how's it going?"}] - print("testing sagemaker") litellm.set_verbose=True response = completion( - model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat", + model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, max_tokens=100, + temperature=0.7, stream=True, - n=2, - hf_model_name="meta-llama/Llama-2-7b-chat-hf", ) # Add any assertions here to check the response complete_response = "" @@ -1068,8 +1065,23 @@ def test_completion_chat_sagemaker(): assert len(complete_response) > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") -test_completion_chat_sagemaker() +# test_completion_chat_sagemaker() +def test_completion_chat_sagemaker_mistral(): + try: + messages = [{"role": "user", "content": "Hey, how's it going?"}] + + response = completion( + model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct", + messages=messages, + max_tokens=100, + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"An error occurred: {str(e)}") + +# test_completion_chat_sagemaker_mistral() def test_completion_bedrock_titan(): try: response = completion( From cb52e3347ed3dae7913c5569a66a2acb3988a5e6 
Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 17:59:57 -0800 Subject: [PATCH 119/125] (fix) proxy: make yaml load print_verbose --- litellm/proxy/proxy_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 80f19986b..a90530068 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -431,7 +431,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): for model in printed_yaml["model_list"]: model["litellm_params"].pop("api_key", None) - print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}") + print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}") ## ENVIRONMENT VARIABLES environment_variables = config.get('environment_variables', None) From 39bb972168fb62436593c8ffa2537d368c112808 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 18:01:58 -0800 Subject: [PATCH 120/125] =?UTF-8?q?bump:=20version=201.10.9=20=E2=86=92=20?= =?UTF-8?q?1.10.10?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6a5e9dc2..3acaf6047 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.10.9" +version = "1.10.10" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.10.9" +version = "1.10.10" version_files = [ "pyproject.toml:^version" ] From 155e99b9a38ad51bbc1c383882e5e80148ff9690 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 18:04:05 -0800 Subject: [PATCH 121/125] (fix) prox cli: remove deprecated param --- litellm/proxy/proxy_cli.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index a76a49b2c..6c17af95a 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -74,7 +74,6 @@ def is_port_in_use(port): @click.option('--drop_params', is_flag=True, help='Drop any unmapped params') @click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt') @click.option('--config', '-c', default=None, help='Configure Litellm') -@click.option('--file', '-f', help='Path to config file') @click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`') @click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. 
By default gets most recent log.') @@ -83,7 +82,7 @@ def is_port_in_use(port): @click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response') @click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with') @click.option('--local', is_flag=True, default=False, help='for local debugging') -def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, file, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health): +def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health): global feature_telemetry args = locals() if local: From 56acded998279b7dfe5e4130fd65315ff18c8873 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 18:07:27 -0800 Subject: [PATCH 122/125] (router) better debugging using config.yaml --- litellm/router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 941ad1318..9d6ab5a10 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -889,7 +889,7 @@ class Router: if "azure" in model_name: if api_base is None: raise ValueError("api_base is required for Azure OpenAI. Set it on your config") - self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, {str(api_base)}, {api_key}") + self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}") if api_version is None: api_version = "2023-07-01-preview" if "gateway.ai.cloudflare.com" in api_base: @@ -960,7 +960,7 @@ class Router: ) else: - self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}, {api_key}") + self.print_verbose(f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}") model["async_client"] = openai.AsyncOpenAI( api_key=api_key, base_url=api_base, From 27d7d7ba9cea2c4e95e56646b41ca45be0ed18b6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 18:11:29 -0800 Subject: [PATCH 123/125] (feat) proxy cli, better description of config yaml param --- litellm/proxy/proxy_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 6c17af95a..7dca11dd4 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -73,7 +73,7 @@ def is_port_in_use(port): @click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls') @click.option('--drop_params', is_flag=True, help='Drop any unmapped params') @click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt') -@click.option('--config', '-c', default=None, help='Configure Litellm') +@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`') @click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. 
Turn this off by doing `--telemetry False`') @click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.') From 48aa00d6c0cb241819793bacff68d88a3441685a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 18:14:01 -0800 Subject: [PATCH 124/125] (fix) proxy - clean up print statement --- litellm/proxy/proxy_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a90530068..af9e1cb4d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -728,9 +728,9 @@ async def startup_event(): global prisma_client, master_key import json worker_config = json.loads(os.getenv("WORKER_CONFIG")) - print(f"worker_config: {worker_config}") + print_verbose(f"worker_config: {worker_config}") initialize(**worker_config) - print(f"prisma client - {prisma_client}") + print_verbose(f"prisma client - {prisma_client}") if prisma_client: await prisma_client.connect() From 642c62f7b71977f8c27368ccba78f42bee129556 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 5 Dec 2023 18:19:15 -0800 Subject: [PATCH 125/125] (fix) proxy: better debugging when -debug is on --- litellm/proxy/proxy_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index af9e1cb4d..6ee98a31e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -427,9 +427,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ## PRINT YAML FOR CONFIRMING IT WORKS printed_yaml = copy.deepcopy(config) printed_yaml.pop("environment_variables", None) - if "model_list" in printed_yaml: - for model in printed_yaml["model_list"]: - model["litellm_params"].pop("api_key", None) print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
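
Taken together, the router and proxy changes above (tracking `deployment_names` and checking them in the `/chat/completions` and `/embeddings` handlers) let a client address one specific deployment instead of a load-balanced model group. A minimal usage sketch against a locally running proxy follows; the base URL, key, and the deployment name `azure/chatgpt-v-2` are placeholders drawn from the test config in these patches, not fixed values:

```python
import openai

# Point the OpenAI client at the LiteLLM proxy (placeholder URL and key).
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")

# "gpt-3.5-turbo" hits the router's model group (load-balanced across deployments),
# while "azure/chatgpt-v-2" pins the request to that single deployment.
for model in ["gpt-3.5-turbo", "azure/chatgpt-v-2"]:
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "hello"}],
    )
    print(model, resp.choices[0].message.content)
```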