From a419d59542d4d557e9e35f6791745019a727e148 Mon Sep 17 00:00:00 2001
From: Joel Eriksson
Date: Sun, 17 Dec 2023 17:27:47 +0200
Subject: [PATCH 01/53] Fix for issue that occurred when proxying to ollama

In the text_completion() function, it previously threw an exception at:

raw_response = response._hidden_params.get("original_response", None)

Due to response being a coroutine object from an ollama_acompletion call,
I added an asyncio.iscoroutine() check for the response and handle it by
calling response = asyncio.run(response)

I also had to fix atext_completion(), where init_response was an instance
of TextCompletionResponse. Since this case was not handled by the if-elif
that checks if init_response is a coroutine, a dict or a ModelResponse
instance, response was unbound which threw an exception on the
"return response" line.

Note that a regular pyright-based linter detects that response is possibly
unbound, and that the same code pattern is used in multiple other places
in main.py. I would suggest that you either change these cases:

init_response = await loop.run_in_executor(...
if isinstance(init_response, ...
    response = init_response
elif asyncio.iscoroutine(init_response):
    response = await init_response

To either just:

response = await loop.run_in_executor(
if asyncio.iscoroutine(response):
    response = await response

Or at the very least, include an else statement and set
response = init_response, so that response is never unbound when the code
proceeds.
---
 litellm/main.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index b7e9ccce2..878d0fa5a 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2016,11 +2016,9 @@ async def atext_completion(*args, **kwargs):
             response = text_completion(*args, **kwargs)
         else:
             # Await normally
-            init_response = await loop.run_in_executor(None, func_with_context)
-            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
-                response = init_response
-            elif asyncio.iscoroutine(init_response):
-                response = await init_response
+            response = await loop.run_in_executor(None, func_with_context)
+            if asyncio.iscoroutine(response):
+                response = await response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
@@ -2196,6 +2194,9 @@ def text_completion(
         response = TextCompletionStreamWrapper(completion_stream=response, model=model)
         return response
-
+    if asyncio.iscoroutine(response):
+        response = asyncio.run(response)
+
     transformed_logprobs = None # only supported for TGI models
     try:

From e214e6ab47c6ac9f3349ce07f6ba41a367cc5b69 Mon Sep 17 00:00:00 2001
From: Joel Eriksson
Date: Sun, 17 Dec 2023 20:23:26 +0200
Subject: [PATCH 02/53] Fix bug when iterating over lines in ollama response

async for line in resp.content.iter_any() will return incomplete lines
when the lines are long, and that results in an exception being thrown
by json.loads() when it tries to parse the incomplete JSON

The default behavior of the stream reader for aiohttp response objects
is to iterate over lines, so just removing .iter_any() fixes the bug
---
 litellm/llms/ollama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index f2a9b0df4..e2be1c2d5 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -195,7 +195,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
             raise 
OllamaError(status_code=resp.status, message=text) completion_string = "" - async for line in resp.content.iter_any(): + async for line in resp.content: if line: try: json_chunk = line.decode("utf-8") From d6ed13fa4f41a29a4e5cbb0ef0a87844230f2a3c Mon Sep 17 00:00:00 2001 From: Ankur Garha Date: Mon, 18 Dec 2023 22:56:08 +0100 Subject: [PATCH 03/53] doc: updated langfuse ver 1.14 in pip install cmd --- docs/my-website/docs/observability/langfuse_integration.md | 2 +- docs/my-website/docs/proxy/logging.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/observability/langfuse_integration.md b/docs/my-website/docs/observability/langfuse_integration.md index 3e5f5603d..d5ccbd085 100644 --- a/docs/my-website/docs/observability/langfuse_integration.md +++ b/docs/my-website/docs/observability/langfuse_integration.md @@ -15,7 +15,7 @@ join our [discord](https://discord.gg/wuPM9dRgDw) ## Pre-Requisites Ensure you have run `pip install langfuse` for this integration ```shell -pip install langfuse litellm +pip install langfuse==1.14.0 litellm ``` ## Quick Start diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 253c299f6..69544ff37 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -459,7 +459,7 @@ We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this **Step 1** Install langfuse ```shell -pip install langfuse +pip install langfuse==1.14.0 ``` **Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback` From 014abddff660f84bfa371cb4a2b4fafafa869bc0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Dec 2023 15:14:57 -0800 Subject: [PATCH 04/53] fix(requirements.txt): pin all dependencies --- requirements.txt | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3cf315935..cc5f8c492 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,24 +3,24 @@ anyio==4.2.0 # openai + http req. openai>=1.0.0 # openai req. fastapi # server dep pydantic>=2.5 # openai req. 
-appdirs # server dep -backoff # server dep -pyyaml # server dep -uvicorn # server dep -boto3 # aws bedrock/sagemaker calls -redis # caching -prisma # for db -mangum # for aws lambda functions -google-generativeai # for vertex ai calls +appdirs==1.4.4 # server dep +backoff==2.2.1 # server dep +pyyaml==6.0 # server dep +uvicorn==0.22.0 # server dep +boto3==1.28.58 # aws bedrock/sagemaker calls +redis==4.6.0 # caching +prisma==0.11.0 # for db +mangum==0.17.0 # for aws lambda functions +google-generativeai==0.1.0 # for vertex ai calls traceloop-sdk==0.5.3 # for open telemetry logging langfuse==1.14.0 # for langfuse self-hosted logging ### LITELLM PACKAGE DEPENDENCIES python-dotenv>=0.2.0 # for env tiktoken>=0.4.0 # for calculating usage importlib-metadata>=6.8.0 # for random utils -tokenizers # for calculating usage -click # for proxy cli +tokenizers==0.14.0 # for calculating usage +click==8.1.7 # for proxy cli jinja2==3.1.2 # for prompt templates certifi>=2023.7.22 # [TODO] clean up -aiohttp # for network calls +aiohttp==3.8.4 # for network calls #### \ No newline at end of file From 34509d8dda9db4bd033585a47edb84bd120856bb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Dec 2023 17:41:41 -0800 Subject: [PATCH 05/53] fix(main.py): return async completion calls --- litellm/llms/openai.py | 54 +++++++++++++++------------ litellm/main.py | 3 +- litellm/tests/test_text_completion.py | 46 ++++++++++++++++------- 3 files changed, 65 insertions(+), 38 deletions(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index c923cbf2d..0731bd509 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -284,7 +284,7 @@ class OpenAIChatCompletion(BaseLLM): additional_args={"complete_input_dict": data}, ) return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) - except Exception as e: + except Exception as e: raise e def streaming(self, @@ -631,24 +631,27 @@ class OpenAITextCompletion(BaseLLM): api_key: str, model: str): async with httpx.AsyncClient() as client: - response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) - response_json = response.json() - if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_base": api_base, - }, - ) + try: + response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) + response_json = response.json() + if response.status_code != 200: + raise OpenAIError(status_code=response.status_code, message=response.text) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) - ## RESPONSE OBJECT - return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) + ## RESPONSE OBJECT + return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) + except Exception as e: + raise e def streaming(self, logging_obj, @@ -687,9 +690,12 @@ class OpenAITextCompletion(BaseLLM): method="POST", timeout=litellm.request_timeout ) as response: - if response.status_code != 200: - raise OpenAIError(status_code=response.status_code, message=response.text) - - streamwrapper = 
CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) - async for transformed_chunk in streamwrapper: - yield transformed_chunk \ No newline at end of file + try: + if response.status_code != 200: + raise OpenAIError(status_code=response.status_code, message=response.text) + + streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) + async for transformed_chunk in streamwrapper: + yield transformed_chunk + except Exception as e: + raise e \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 6e8931191..52d2ae5b6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2205,7 +2205,8 @@ def text_completion( if stream == True or kwargs.get("stream", False) == True: response = TextCompletionStreamWrapper(completion_stream=response, model=model) return response - + if kwargs.get("acompletion", False) == True: + return response transformed_logprobs = None # only supported for TGI models try: diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py index 9257a07f3..f75bd2f7f 100644 --- a/litellm/tests/test_text_completion.py +++ b/litellm/tests/test_text_completion.py @@ -169,17 +169,37 @@ def test_text_completion_stream(): # test_text_completion_stream() -async def test_text_completion_async_stream(): - try: - response = await atext_completion( - model="text-completion-openai/text-davinci-003", - prompt="good morning", - stream=True, - max_tokens=10, - ) - async for chunk in response: - print(f"chunk: {chunk}") - except Exception as e: - pytest.fail(f"GOT exception for HF In streaming{e}") +# async def test_text_completion_async_stream(): +# try: +# response = await atext_completion( +# model="text-completion-openai/text-davinci-003", +# prompt="good morning", +# stream=True, +# max_tokens=10, +# ) +# async for chunk in response: +# print(f"chunk: {chunk}") +# except Exception as e: +# pytest.fail(f"GOT exception for HF In streaming{e}") -asyncio.run(test_text_completion_async_stream()) \ No newline at end of file +# asyncio.run(test_text_completion_async_stream()) + +def test_async_text_completion(): + litellm.set_verbose = True + print('test_async_text_completion') + async def test_get_response(): + try: + response = await litellm.atext_completion( + model="gpt-3.5-turbo-instruct", + prompt="good morning", + stream=False, + max_tokens=10 + ) + print(f"response: {response}") + except litellm.Timeout as e: + print(e) + except Exception as e: + print(e) + + asyncio.run(test_get_response()) +test_async_text_completion() \ No newline at end of file From 071283c102e973c4063415f3db0ded466f32664d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Dec 2023 17:50:26 -0800 Subject: [PATCH 06/53] fix(router.py): init deployment_latency_map even if model_list is empty --- litellm/proxy/proxy_config.yaml | 2 +- litellm/router.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 0180d232e..b9f29a584 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -47,7 +47,7 @@ litellm_settings: # setting callback class # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] -general_settings: +# general_settings: environment_variables: # otel: True # OpenTelemetry Logger diff --git 
a/litellm/router.py b/litellm/router.py index 410d4964e..0276f5a44 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -84,11 +84,11 @@ class Router: self.set_verbose = set_verbose self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2 + self.deployment_latency_map = {} if model_list: model_list = copy.deepcopy(model_list) self.set_model_list(model_list) self.healthy_deployments: List = self.model_list - self.deployment_latency_map = {} for m in model_list: self.deployment_latency_map[m["litellm_params"]["model"]] = 0 From f73c4b494ce05083b1e6804dc8cd1f32d325d271 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 18 Dec 2023 17:57:22 -0800 Subject: [PATCH 07/53] docs(model_management.md): adding docs on how to add model metadata in the config.yaml --- docs/my-website/docs/proxy/model_management.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/my-website/docs/proxy/model_management.md b/docs/my-website/docs/proxy/model_management.md index 0cd4ab829..8160e2aa7 100644 --- a/docs/my-website/docs/proxy/model_management.md +++ b/docs/my-website/docs/proxy/model_management.md @@ -1,6 +1,17 @@ # Model Management Add new models + Get model info without restarting proxy. +## In Config.yaml + +```yaml +model_list: + - model_name: text-davinci-003 + litellm_params: + model: "text-completion-openai/text-davinci-003" + model_info: + metadata: "here's additional metadata on the model" # returned via GET /model/info +``` + ## Get Model Information Retrieve detailed information about each model listed in the `/models` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes. From b82fcd51d71fa84279a2689a4477c282f17d38f2 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 11:32:28 +0530 Subject: [PATCH 08/53] (docs) ollama/llava --- docs/my-website/docs/providers/ollama.md | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/my-website/docs/providers/ollama.md b/docs/my-website/docs/providers/ollama.md index f5a1b0d6d..35a3a560c 100644 --- a/docs/my-website/docs/providers/ollama.md +++ b/docs/my-website/docs/providers/ollama.md @@ -77,6 +77,45 @@ Ollama supported models: https://github.com/jmorganca/ollama | Nous-Hermes 13B | `completion(model='ollama/nous-hermes:13b', messages, api_base="http://localhost:11434", stream=True)` | | Wizard Vicuna Uncensored | `completion(model='ollama/wizard-vicuna', messages, api_base="http://localhost:11434", stream=True)` | +## Ollama Vision Models +| Model Name | Function Call | +|------------------|--------------------------------------| +| llava | `completion('ollama/llava', messages)` | + +#### Using Ollama Vision Models + +Call `ollama/llava` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models) + +LiteLLM Supports the following image types passed in `url` +- Base64 encoded svgs + +**Example Request** +```python +import litellm + +response = litellm.completion( + model = "ollama/llava", + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Whats in this image?" 
+ }, + { + "type": "image_url", + "image_url": { + "url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroW
j4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC" + } + } + ] + } + ], +) +print(response) +``` + ## LiteLLM/Ollama Docker Image From ce1b0b89bae72c45b09d1737304686049032e512 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 12:55:20 +0530 Subject: [PATCH 09/53] (fix) proxy - health checks support cli model --- litellm/proxy/health_check.py | 7 +++++-- litellm/proxy/proxy_server.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index f8e56c059..24ace2e94 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -96,7 +96,7 @@ async def _perform_health_check(model_list: list): -async def perform_health_check(model_list: list, model: Optional[str] = None): +async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None): """ Perform a health check on the system. @@ -104,7 +104,10 @@ async def perform_health_check(model_list: list, model: Optional[str] = None): (bool): True if the health check passes, False otherwise. """ if not model_list: - return [], [] + if cli_model: + model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}] + else: + return [], [] if model is not None: model_list = [x for x in model_list if x["litellm_params"]["model"] == model] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 13cdc2d02..0ebd8242d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1514,9 +1514,18 @@ async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query ``` else, the health checks will be run on models when /health is called. 
""" - global health_check_results, use_background_health_checks + global health_check_results, use_background_health_checks, user_model if llm_model_list is None: + # if no router set, check if user set a model using litellm --model ollama/llama2 + if user_model is not None: + healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=[], cli_model=user_model) + return { + "healthy_endpoints": healthy_endpoints, + "unhealthy_endpoints": unhealthy_endpoints, + "healthy_count": len(healthy_endpoints), + "unhealthy_count": len(unhealthy_endpoints), + } raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"error": "Model list not initialized"}, From 9995229b97ccb70469d68dd91926facdcea67e8d Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 18:48:34 +0530 Subject: [PATCH 10/53] (fix) proxy + ollama - raise exception correctly --- litellm/llms/ollama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index e2be1c2d5..3ac9ec2a8 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -230,3 +230,4 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj): return model_response except Exception as e: traceback.print_exc() + raise e From 84fde01461c06e7e56b9fc95665cd4bb95391765 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 19:11:35 +0530 Subject: [PATCH 11/53] (fix) docs proxy + auth curl request --- docs/my-website/docs/proxy/virtual_keys.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 93abc789b..181ccf648 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -39,8 +39,8 @@ litellm --config /path/to/config.yaml ```shell curl 'http://0.0.0.0:8000/key/generate' \ ---h 'Authorization: Bearer sk-1234' \ ---d '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}' +--header 'Authorization: Bearer sk-1234' \ +--data '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}' ``` - `models`: *list or null (optional)* - Specify the models a token has access too. If null, then token has access to all models on server. From dbcff752b36927fb8b6980a6e126b660678787aa Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 19:22:51 +0530 Subject: [PATCH 12/53] (docs) exception mapping - add more details --- docs/my-website/docs/exception_mapping.md | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/my-website/docs/exception_mapping.md b/docs/my-website/docs/exception_mapping.md index 2dcc584dc..6d27f0094 100644 --- a/docs/my-website/docs/exception_mapping.md +++ b/docs/my-website/docs/exception_mapping.md @@ -1,13 +1,18 @@ # Exception Mapping LiteLLM maps exceptions across all providers to their OpenAI counterparts. 
-- Rate Limit Errors -- Invalid Request Errors -- Authentication Errors -- Timeout Errors `openai.APITimeoutError` -- ServiceUnavailableError -- APIError -- APIConnectionError + +| Status Code | Error Type | +|-------------|--------------------------| +| 400 | BadRequestError | +| 401 | AuthenticationError | +| 403 | PermissionDeniedError | +| 404 | NotFoundError | +| 422 | UnprocessableEntityError | +| 429 | RateLimitError | +| >=500 | InternalServerError | +| N/A | APIConnectionError | + Base case we return APIConnectionError From 8b0fa2322cabd86dd458524a1d81c0257183e3bb Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 19:23:52 +0530 Subject: [PATCH 13/53] (docs) exception mapping azure openai --- docs/my-website/docs/exception_mapping.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/my-website/docs/exception_mapping.md b/docs/my-website/docs/exception_mapping.md index 6d27f0094..6dccd4048 100644 --- a/docs/my-website/docs/exception_mapping.md +++ b/docs/my-website/docs/exception_mapping.md @@ -88,6 +88,7 @@ Base case - we return the original exception. |---------------|----------------------------|---------------------|---------------------|---------------|-------------------------| | Anthropic | ✅ | ✅ | ✅ | ✅ | | | OpenAI | ✅ | ✅ |✅ |✅ |✅| +| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅| | Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | Cohere | ✅ | ✅ | ✅ | ✅ | ✅ | | Huggingface | ✅ | ✅ | ✅ | ✅ | | From 9eb487efb37c2257bc5499e9a1dcc6c4906facfa Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 19 Dec 2023 19:29:05 +0530 Subject: [PATCH 14/53] (docs) exception mapping --- docs/my-website/docs/exception_mapping.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/my-website/docs/exception_mapping.md b/docs/my-website/docs/exception_mapping.md index 6dccd4048..c6c9bb255 100644 --- a/docs/my-website/docs/exception_mapping.md +++ b/docs/my-website/docs/exception_mapping.md @@ -11,6 +11,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts. 
| 422 | UnprocessableEntityError | | 429 | RateLimitError | | >=500 | InternalServerError | +| N/A | ContextWindowExceededError| | N/A | APIConnectionError | From 5936664a1671fdda2381b82d29892e794676ca0e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 19 Dec 2023 15:00:52 +0000 Subject: [PATCH 15/53] fix(ollama.py): raise async errors --- litellm/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index 45d5d02f0..90678f45c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1572,11 +1572,14 @@ def client(original_function): def post_call_processing(original_response, model): try: - call_type = original_function.__name__ - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: - model_response = original_response['choices'][0]['message']['content'] - ### POST-CALL RULES ### - rules_obj.post_call_rules(input=model_response, model=model) + if original_response is None: + pass + else: + call_type = original_function.__name__ + if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: + model_response = original_response['choices'][0]['message']['content'] + ### POST-CALL RULES ### + rules_obj.post_call_rules(input=model_response, model=model) except Exception as e: raise e From c5340b87098638f781ce2dfec3349c47e11232b7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 19 Dec 2023 15:25:29 +0000 Subject: [PATCH 16/53] fix(utils.py): vertex ai exception mapping --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 90678f45c..98806c101 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4719,7 +4719,7 @@ def exception_type( ) elif "403" in error_str: exception_mapping_worked = True - raise U( + raise BadRequestError( message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", From b8d7cafb8ddd928c3935d169e38f610fbbd7b706 Mon Sep 17 00:00:00 2001 From: navidre Date: Tue, 19 Dec 2023 14:04:20 -0600 Subject: [PATCH 17/53] Sample code to prevent logging API key Sample code in documentation to prevent logging API key. 
Ideally should be implemented in the codebase if not already --- .../docs/observability/slack_integration.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/observability/slack_integration.md b/docs/my-website/docs/observability/slack_integration.md index 162ceb450..0ca7f6166 100644 --- a/docs/my-website/docs/observability/slack_integration.md +++ b/docs/my-website/docs/observability/slack_integration.md @@ -41,6 +41,18 @@ def send_slack_alert( # get it from https://api.slack.com/messaging/webhooks slack_webhook_url = os.environ['SLACK_WEBHOOK_URL'] # "https://hooks.slack.com/services/<>/<>/<>" + # Remove api_key from kwargs under litellm_params + if kwargs.get('litellm_params'): + kwargs['litellm_params'].pop('api_key', None) + if kwargs['litellm_params'].get('metadata'): + kwargs['litellm_params']['metadata'].pop('deployment', None) + # Remove deployment under metadata + if kwargs.get('metadata'): + kwargs['metadata'].pop('deployment', None) + # Prevent api_key from being logged + if kwargs.get('api_key'): + kwargs.pop('api_key', None) + # Define the text payload, send data available in litellm custom_callbacks text_payload = f"""LiteLLM Logging: kwargs: {str(kwargs)}\n\n, response: {str(completion_response)}\n\n, start time{str(start_time)} end time: {str(end_time)} """ @@ -90,4 +102,4 @@ response = litellm.completion( - [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version) - [Community Discord 💭](https://discord.gg/wuPM9dRgDw) - Our numbers 📞 +1 (770) 8783-106 / ‭+1 (412) 618-6238‬ -- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai \ No newline at end of file +- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai From cd34b859df5be3858f16820875490d1965a5bb94 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 05:49:45 +0530 Subject: [PATCH 18/53] (docs) swagger endpoint --- docs/my-website/docs/proxy/quick_start.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index 977ebfb51..8aaadb677 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -349,6 +349,12 @@ litellm --config your_config.yaml [**More Info**](./configs.md) ## Server Endpoints + +:::note + +You can see Swagger Docs for the server on root http://0.0.0.0:8000 + +::: - POST `/chat/completions` - chat completions endpoint to call 100+ LLMs - POST `/completions` - completions endpoint - POST `/embeddings` - embedding endpoint for Azure, OpenAI, Huggingface endpoints From 8b26e64b5d35c0745cd542bc80593a031a576aff Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 06:02:05 +0530 Subject: [PATCH 19/53] (fix) proxy: add link t swagger docs on startup --- litellm/proxy/proxy_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0ebd8242d..1aaab7d14 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -774,6 +774,7 @@ def initialize( print(f"\033[1;34mLiteLLM: Test your local proxy with: \"litellm --test\" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n") print(f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n") print("\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n") + print(f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:8000 \033[0m\n") # for streaming def 
data_generator(response): print_verbose("inside generator") From 9548334e2f3116a7e8e1fbc661f43799b28d4e6e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 06:27:26 +0530 Subject: [PATCH 20/53] (docs) swagger docs add description --- litellm/proxy/proxy_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1aaab7d14..129e70f1f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -110,7 +110,7 @@ import json import logging from typing import Union -app = FastAPI(docs_url="/", title="LiteLLM API") +app = FastAPI(docs_url="/", title="LiteLLM API", description="Proxy Server to call 100+ LLMs in the OpenAI format") router = APIRouter() origins = ["*"] From aa78415894eb997734345bba24d28a865d59811e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 06:29:36 +0530 Subject: [PATCH 21/53] (docs) swager - add embeddings tag --- litellm/proxy/proxy_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 129e70f1f..154dd5724 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1045,8 +1045,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap detail=error_msg ) -@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) -@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) +@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"]) +@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["embeddings"]) async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): global proxy_logging_obj try: From 2d15e5384bc1e9c08fe79b5c9a186685abcec256 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Tue, 19 Dec 2023 22:26:55 -0500 Subject: [PATCH 22/53] Add partial support of vertexai safety settings --- litellm/llms/vertex_ai.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 5457ee40d..f55575227 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -181,6 +181,7 @@ def completion( from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig + from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types vertexai.init( @@ -193,6 +194,15 @@ def completion( if k not in optional_params: optional_params[k] = v + ## Process safety settings into format expected by vertex AI + if "safety_settings" in optional_params: + safety_settings = optional_params.pop("safety_settings") + if not isinstance(safety_settings, list): + raise ValueError("safety_settings must be a list") + if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): + raise ValueError("safety_settings must be a list of dicts") + safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings] + # vertexai does not use an API key, it looks for credentials.json in the environment prompt = " 
".join([message["content"] for message in messages if isinstance(message["content"], str)]) @@ -238,16 +248,16 @@ def completion( if "stream" in optional_params and optional_params["stream"] == True: stream = optional_params.pop("stream") - request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" + request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" ## LOGGING logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream) + model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream) optional_params["stream"] = True return model_response - request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n" + request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n" ## LOGGING logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) - response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params)) + response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings) completion_response = response_obj.text response_obj = response_obj._raw_response elif mode == "vision": @@ -258,12 +268,13 @@ def completion( content = [prompt] + images if "stream" in optional_params and optional_params["stream"] == True: stream = optional_params.pop("stream") - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" + request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) model_response = llm_model.generate_content( contents=content, generation_config=GenerationConfig(**optional_params), + safety_settings=safety_settings, stream=True ) optional_params["stream"] = True @@ -276,7 +287,8 @@ def completion( ## LLM Call response = llm_model.generate_content( contents=content, - generation_config=GenerationConfig(**optional_params) + generation_config=GenerationConfig(**optional_params), + safety_settings=safety_settings, ) completion_response = response.text response_obj = response._raw_response From 229b56fc35366f14b257966289a51c5527a47531 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 09:04:56 +0530 Subject: [PATCH 23/53] (docs) swagger - add embedding tag --- litellm/proxy/proxy_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 154dd5724..944bfe62d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -872,9 +872,9 @@ def model_list(): object="list", ) -@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)]) -@router.post("/completions", dependencies=[Depends(user_api_key_auth)]) 
-@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)]) +@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) +@router.post("/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) +@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]) async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): global user_temperature, user_request_timeout, user_max_tokens, user_api_base try: From c4b7ab6579e419d7c7df7711bff212fc095d3b88 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 09:44:26 +0530 Subject: [PATCH 24/53] (feat) proxy - add metadata for keys --- litellm/proxy/schema.prisma | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index ab4fc5e00..6cfcdb866 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -17,4 +17,5 @@ model LiteLLM_VerificationToken { config Json @default("{}") user_id String? max_parallel_requests Int? + metadata Json @default("{}") } \ No newline at end of file From 7ad21de4417be7dd19467f63954b05b4c2faa8dc Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 09:55:35 +0530 Subject: [PATCH 25/53] (feat) proxy /key/generate add metadata to _types --- litellm/proxy/_types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index cb04f32a5..233c1b642 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -120,6 +120,7 @@ class GenerateKeyRequest(LiteLLMBase): spend: Optional[float] = 0 user_id: Optional[str] = None max_parallel_requests: Optional[int] = None + metadata: Optional[dict] = {} class UpdateKeyRequest(LiteLLMBase): key: str @@ -130,6 +131,7 @@ class UpdateKeyRequest(LiteLLMBase): spend: Optional[float] = None user_id: Optional[str] = None max_parallel_requests: Optional[int] = None + metadata: Optional[dict] = {} class GenerateKeyResponse(LiteLLMBase): key: str @@ -158,6 +160,7 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k user_id: Optional[str] = None max_parallel_requests: Optional[int] = None duration: str = "1h" + metadata: dict = {} class ConfigGeneralSettings(LiteLLMBase): """ From 683a1ee979e36a620ddf7164f561a4bc9c3ebc59 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 09:57:34 +0530 Subject: [PATCH 26/53] (feat) proxy key/generate pass metadata in requests --- litellm/proxy/proxy_server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 944bfe62d..10b40321e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -616,7 +616,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): router = litellm.Router(**router_params) # type:ignore return router, model_list, general_settings -async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None): +async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: 
Optional[int]=None, metadata: Optional[dict] = {}): global prisma_client if prisma_client is None: @@ -653,6 +653,7 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: aliases_json = json.dumps(aliases) config_json = json.dumps(config) + metadata_json = json.dumps(metadata) user_id = user_id or str(uuid.uuid4()) try: # Create a new verification token (you may want to enhance this logic based on your needs) @@ -664,7 +665,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: "config": config_json, "spend": spend, "user_id": user_id, - "max_parallel_requests": max_parallel_requests + "max_parallel_requests": max_parallel_requests, + "metadata": metadata_json } new_verification_token = await prisma_client.insert_data(data=verification_token_data) except Exception as e: @@ -1141,6 +1143,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat - config: Optional[dict] - any key-specific configs, overrides config in config.yaml - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. + - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } Returns: - key: (str) The generated api key From 7d20ea23d171b8ce80eb17c3752fa41721838e40 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 10:07:38 +0530 Subject: [PATCH 27/53] (docs) set openrouter params --- docs/my-website/docs/providers/openrouter.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/providers/openrouter.md b/docs/my-website/docs/providers/openrouter.md index f4888005d..2c55affae 100644 --- a/docs/my-website/docs/providers/openrouter.md +++ b/docs/my-website/docs/providers/openrouter.md @@ -20,7 +20,7 @@ response = completion( ) ``` -### OpenRouter Completion Models +## OpenRouter Completion Models | Model Name | Function Call | Required OS Variables | |---------------------------|-----------------------------------------------------|--------------------------------------------------------------| @@ -35,3 +35,19 @@ response = completion( | openrouter/meta-llama/llama-2-13b-chat | `completion('openrouter/meta-llama/llama-2-13b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` | | openrouter/meta-llama/llama-2-70b-chat | `completion('openrouter/meta-llama/llama-2-70b-chat', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` | +## Passing OpenRouter Params - transforms, models, route + +Pass `transforms`, `models`, `route`as arguments to `litellm.completion()` + +```python +import os +from litellm import completion +os.environ["OPENROUTER_API_KEY"] = "" + +response = completion( + model="openrouter/google/palm-2-chat-bison", + messages=messages, + transforms = [""], + route= "" + ) +``` \ No newline at end of file From bab8f3350da9183ab98b4b799762233711d26c9c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 20 Dec 2023 10:09:09 +0530 Subject: [PATCH 28/53] (docs) openrouter --- docs/my-website/docs/providers/openrouter.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff 
--git a/docs/my-website/docs/providers/openrouter.md b/docs/my-website/docs/providers/openrouter.md index 2c55affae..4f966b5c9 100644 --- a/docs/my-website/docs/providers/openrouter.md +++ b/docs/my-website/docs/providers/openrouter.md @@ -1,5 +1,5 @@ # OpenRouter -LiteLLM supports all the text models from [OpenRouter](https://openrouter.ai/docs) +LiteLLM supports all the text / chat / vision models from [OpenRouter](https://openrouter.ai/docs) Open In Colab @@ -22,8 +22,8 @@ response = completion( ## OpenRouter Completion Models -| Model Name | Function Call | Required OS Variables | -|---------------------------|-----------------------------------------------------|--------------------------------------------------------------| +| Model Name | Function Call | +|---------------------------|-----------------------------------------------------| | openrouter/openai/gpt-3.5-turbo | `completion('openrouter/openai/gpt-3.5-turbo', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` | | openrouter/openai/gpt-3.5-turbo-16k | `completion('openrouter/openai/gpt-3.5-turbo-16k', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` | | openrouter/openai/gpt-4 | `completion('openrouter/openai/gpt-4', messages)` | `os.environ['OR_SITE_URL']`,`os.environ['OR_APP_NAME']`,`os.environ['OPENROUTER_API_KEY']` | From f0df28362a52bca6c6b00745931f7dc6cb000846 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 14:59:43 +0530 Subject: [PATCH 29/53] feat(ollama.py): add support for ollama function calling --- litellm/llms/ollama.py | 111 ++++++++++-------- litellm/llms/prompt_templates/factory.py | 4 +- litellm/main.py | 18 +-- litellm/tests/test_hf_prompt_templates.py | 8 ++ litellm/tests/test_ollama_local.py | 130 +++++++++++++++++++++- litellm/utils.py | 14 ++- 6 files changed, 211 insertions(+), 74 deletions(-) diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 3ac9ec2a8..ceab2c7d3 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -1,14 +1,10 @@ import requests, types, time -import json +import json, uuid import traceback from typing import Optional import litellm import httpx, aiohttp, asyncio -try: - from async_generator import async_generator, yield_ # optional dependency - async_generator_imported = True -except ImportError: - async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement +from .prompt_templates.factory import prompt_factory, custom_prompt class OllamaError(Exception): def __init__(self, status_code, message): @@ -106,9 +102,8 @@ class OllamaConfig(): and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) and v is not None} - # ollama implementation -def get_ollama_response_stream( +def get_ollama_response( api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?", @@ -129,6 +124,7 @@ def get_ollama_response_stream( if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v + optional_params["stream"] = optional_params.get("stream", False) data = { "model": model, "prompt": prompt, @@ -146,9 +142,41 @@ def get_ollama_response_stream( else: response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) return response - - else: + elif optional_params.get("stream", False): 
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) + response = requests.post( + url=f"{url}", + json=data, + ) + if response.status_code != 200: + raise OllamaError(status_code=response.status_code, message=response.text) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=response.text, + additional_args={ + "headers": None, + "api_base": api_base, + }, + ) + + response_json = response.json() + + ## RESPONSE OBJECT + model_response["choices"][0]["finish_reason"] = "stop" + if optional_params.get("format", "") == "json": + message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) + model_response["choices"][0]["message"] = message + else: + model_response["choices"][0]["message"]["content"] = response_json["response"] + model_response["created"] = int(time.time()) + model_response["model"] = "ollama/" + model + prompt_tokens = response_json["prompt_eval_count"] # type: ignore + completion_tokens = response_json["eval_count"] + model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) + return model_response def ollama_completion_stream(url, data, logging_obj): with httpx.stream( @@ -157,13 +185,15 @@ def ollama_completion_stream(url, data, logging_obj): method="POST", timeout=litellm.request_timeout ) as response: - if response.status_code != 200: - raise OllamaError(status_code=response.status_code, message=response.text) - - streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) - for transformed_chunk in streamwrapper: - yield transformed_chunk - + try: + if response.status_code != 200: + raise OllamaError(status_code=response.status_code, message=response.text) + + streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) + for transformed_chunk in streamwrapper: + yield transformed_chunk + except Exception as e: + raise e async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): try: @@ -194,38 +224,29 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj): text = await resp.text() raise OllamaError(status_code=resp.status, message=text) - completion_string = "" - async for line in resp.content: - if line: - try: - json_chunk = line.decode("utf-8") - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - j = json.loads(chunk) - if "error" in j: - completion_obj = { - "role": "assistant", - "content": "", - "error": j - } - raise Exception(f"OllamError - {chunk}") - if "response" in j: - completion_obj = { - "role": "assistant", - "content": j["response"], - } - completion_string = completion_string + completion_obj["content"] - except Exception as e: - traceback.print_exc() - + ## LOGGING + logging_obj.post_call( + input=data['prompt'], + api_key="", + original_response=resp.text, + additional_args={ + "headers": None, + "api_base": url, + }, + ) + + response_json = await resp.json() ## RESPONSE OBJECT model_response["choices"][0]["finish_reason"] = "stop" - model_response["choices"][0]["message"]["content"] = completion_string + if data.get("format", "") == "json": + message = litellm.Message(content=None, tool_calls=[{"id": 
f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) + model_response["choices"][0]["message"] = message + else: + model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["created"] = int(time.time()) model_response["model"] = "ollama/" + data['model'] - prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore - completion_tokens = len(encoding.encode(completion_string)) + prompt_tokens = response_json["prompt_eval_count"] # type: ignore + completion_tokens = response_json["eval_count"] model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) return model_response except Exception as e: diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 4596e2b62..d908231cb 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -348,7 +348,7 @@ def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/r # Function call template def function_call_prompt(messages: list, functions: list): - function_prompt = "The following functions are available to you:" + function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:" for function in functions: function_prompt += f"""\n{function}\n""" @@ -425,6 +425,6 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str return alpaca_pt(messages=messages) else: return hf_chat_template(original_model_name, messages) - except: + except Exception as e: return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. 
base Llama2) diff --git a/litellm/main.py b/litellm/main.py index 1e2a8323a..0018844c3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1329,23 +1329,11 @@ def completion( optional_params["images"] = images ## LOGGING - generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding) + generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding) if acompletion is True or optional_params.get("stream", False) == True: return generator - else: - response_string = "" - for chunk in generator: - response_string+=chunk['content'] - - ## RESPONSE OBJECT - model_response["choices"][0]["finish_reason"] = "stop" - model_response["choices"][0]["message"]["content"] = response_string - model_response["created"] = int(time.time()) - model_response["model"] = "ollama/" + model - prompt_tokens = len(encoding.encode(prompt)) # type: ignore - completion_tokens = len(encoding.encode(response_string)) - model_response["usage"] = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) - response = model_response + + response = generator elif ( custom_llm_provider == "baseten" or litellm.api_base == "https://app.baseten.co" diff --git a/litellm/tests/test_hf_prompt_templates.py b/litellm/tests/test_hf_prompt_templates.py index 9f1aedcff..c67779f87 100644 --- a/litellm/tests/test_hf_prompt_templates.py +++ b/litellm/tests/test_hf_prompt_templates.py @@ -17,6 +17,14 @@ def test_prompt_formatting(): assert prompt == "[INST] Be a good bot [/INST] [INST] Hello world [/INST]" except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + +def test_prompt_formatting_custom_model(): + try: + prompt = prompt_factory(model="ehartford/dolphin-2.5-mixtral-8x7b", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}], custom_llm_provider="huggingface") + print(f"prompt: {prompt}") + except Exception as e: + pytest.fail(f"An exception occurred: {str(e)}") +# test_prompt_formatting_custom_model() # def logger_fn(user_model_dict): # return # print(f"user_model_dict: {user_model_dict}") diff --git a/litellm/tests/test_ollama_local.py b/litellm/tests/test_ollama_local.py index 5beae7033..35be97506 100644 --- a/litellm/tests/test_ollama_local.py +++ b/litellm/tests/test_ollama_local.py @@ -16,23 +16,61 @@ # user_message = "respond in 20 words. who are you?" # messages = [{ "content": user_message,"role": "user"}] +# def test_ollama_streaming(): +# try: +# litellm.set_verbose = False +# messages = [ +# {"role": "user", "content": "What is the weather like in Boston?"} +# ] +# functions = [ +# { +# "name": "get_current_weather", +# "description": "Get the current weather in a given location", +# "parameters": { +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state, e.g. 
San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": ["celsius", "fahrenheit"] +# } +# }, +# "required": ["location"] +# } +# } +# ] +# response = litellm.completion(model="ollama/mistral", +# messages=messages, +# functions=functions, +# stream=True) +# for chunk in response: +# print(f"CHUNK: {chunk}") +# except Exception as e: +# print(e) + +# test_ollama_streaming() + # async def test_async_ollama_streaming(): # try: -# litellm.set_verbose = True +# litellm.set_verbose = False # response = await litellm.acompletion(model="ollama/mistral-openorca", # messages=[{"role": "user", "content": "Hey, how's it going?"}], # stream=True) # async for chunk in response: -# print(chunk) +# print(f"CHUNK: {chunk}") # except Exception as e: # print(e) -# asyncio.run(test_async_ollama_streaming()) +# # asyncio.run(test_async_ollama_streaming()) # def test_completion_ollama(): # try: +# litellm.set_verbose = True # response = completion( -# model="ollama/llama2", +# model="ollama/mistral", # messages=[{"role": "user", "content": "Hey, how's it going?"}], # max_tokens=200, # request_timeout = 10, @@ -44,7 +82,87 @@ # except Exception as e: # pytest.fail(f"Error occurred: {e}") -# test_completion_ollama() +# # test_completion_ollama() + +# def test_completion_ollama_function_calling(): +# try: +# litellm.set_verbose = True +# messages = [ +# {"role": "user", "content": "What is the weather like in Boston?"} +# ] +# functions = [ +# { +# "name": "get_current_weather", +# "description": "Get the current weather in a given location", +# "parameters": { +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state, e.g. San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": ["celsius", "fahrenheit"] +# } +# }, +# "required": ["location"] +# } +# } +# ] +# response = completion( +# model="ollama/mistral", +# messages=messages, +# functions=functions, +# max_tokens=200, +# request_timeout = 10, +# ) +# for chunk in response: +# print(chunk) +# print(response) +# except Exception as e: +# pytest.fail(f"Error occurred: {e}") +# # test_completion_ollama_function_calling() + +# async def async_test_completion_ollama_function_calling(): +# try: +# litellm.set_verbose = True +# messages = [ +# {"role": "user", "content": "What is the weather like in Boston?"} +# ] +# functions = [ +# { +# "name": "get_current_weather", +# "description": "Get the current weather in a given location", +# "parameters": { +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state, e.g. 
San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": ["celsius", "fahrenheit"] +# } +# }, +# "required": ["location"] +# } +# } +# ] +# response = await litellm.acompletion( +# model="ollama/mistral", +# messages=messages, +# functions=functions, +# max_tokens=200, +# request_timeout = 10, +# ) +# print(response) +# except Exception as e: +# pytest.fail(f"Error occurred: {e}") + +# # asyncio.run(async_test_completion_ollama_function_calling()) + # def test_completion_ollama_with_api_base(): # try: @@ -197,7 +315,7 @@ # ) # print("Response from ollama/llava") # print(response) -# test_ollama_llava() +# # test_ollama_llava() # # PROCESSED CHUNK PRE CHUNK CREATOR diff --git a/litellm/utils.py b/litellm/utils.py index 98806c101..c449a239e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2390,10 +2390,15 @@ def get_optional_params( # use the openai defaults non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])} optional_params = {} ## raise exception if function calling passed in for a provider that doesn't support it - if "functions" in non_default_params or "function_call" in non_default_params: + if "functions" in non_default_params or "function_call" in non_default_params or "tools" in non_default_params: if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure": - if litellm.add_function_to_prompt: # if user opts to add it to prompt instead - optional_params["functions_unsupported_model"] = non_default_params.pop("functions") + if custom_llm_provider == "ollama": + # ollama actually supports json output + optional_params["format"] = "json" + litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt + optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) + elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead + optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) else: raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.") @@ -5192,9 +5197,6 @@ def exception_type( raise original_exception raise original_exception elif custom_llm_provider == "ollama": - if "no attribute 'async_get_ollama_response_stream" in error_str: - exception_mapping_worked = True - raise ImportError("Import error - trying to use async for ollama. import async_generator failed. 
Try 'pip install async_generator'") if isinstance(original_exception, dict): error_str = original_exception.get("error", "") else: From b3962e483fbb4860d7582f22266230a9e46be8c9 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 16:37:21 +0530 Subject: [PATCH 30/53] feat(azure.py): add support for azure image generations endpoint --- litellm/llms/azure.py | 14 +++-- litellm/llms/custom_httpx/azure_dall_e_2.py | 64 +++++++++++++++++++++ litellm/main.py | 3 +- litellm/tests/test_completion.py | 2 +- litellm/tests/test_image_generation.py | 17 ++++-- litellm/utils.py | 1 + 6 files changed, 90 insertions(+), 11 deletions(-) create mode 100644 litellm/llms/custom_httpx/azure_dall_e_2.py diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index e78529490..026f06fb8 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -6,6 +6,7 @@ from typing import Callable, Optional from litellm import OpenAIConfig import litellm, json import httpx +from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport from openai import AzureOpenAI, AsyncAzureOpenAI class AzureOpenAIError(Exception): @@ -464,11 +465,12 @@ class AzureChatCompletion(BaseLLM): raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) def image_generation(self, - prompt: list, + prompt: str, timeout: float, model: Optional[str]=None, api_key: Optional[str] = None, api_base: Optional[str] = None, + api_version: Optional[str] = None, model_response: Optional[litellm.utils.ImageResponse] = None, logging_obj=None, optional_params=None, @@ -477,9 +479,12 @@ class AzureChatCompletion(BaseLLM): ): exception_mapping_worked = False try: - model = model + if model and len(model) > 0: + model = model + else: + model = None data = { - # "model": model, + "model": model, "prompt": prompt, **optional_params } @@ -492,7 +497,8 @@ class AzureChatCompletion(BaseLLM): # return response if client is None: - azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) # type: ignore + client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) + azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore else: azure_client = client diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py new file mode 100644 index 000000000..c5263bd49 --- /dev/null +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -0,0 +1,64 @@ +import time +import json +import httpx + +class CustomHTTPTransport(httpx.HTTPTransport): + """ + This class was written as a workaround to support dall-e-2 on openai > v1.x + + Refer to this issue for more: https://github.com/openai/openai-python/issues/692 + """ + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + if "images/generations" in request.url.path and request.url.params[ + "api-version" + ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ]: + request.url = request.url.copy_with(path="/openai/images/generations:submit") + response = super().handle_request(request) + operation_location_url = response.headers["operation-location"] + request.url = httpx.URL(operation_location_url) + request.method = "GET" + response = 
super().handle_request(request) + response.read() + + timeout_secs: int = 120 + start_time = time.time() + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} + return httpx.Response( + status_code=400, + headers=response.headers, + content=json.dumps(timeout).encode("utf-8"), + request=request, + ) + + time.sleep(int(response.headers.get("retry-after")) or 10) + response = super().handle_request(request) + response.read() + + if response.json()["status"] == "failed": + error_data = response.json() + return httpx.Response( + status_code=400, + headers=response.headers, + content=json.dumps(error_data).encode("utf-8"), + request=request, + ) + + result = response.json()["result"] + return httpx.Response( + status_code=200, + headers=response.headers, + content=json.dumps(result).encode("utf-8"), + request=request, + ) + return super().handle_request(request) \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 0018844c3..318ba8ffb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2307,8 +2307,7 @@ def image_generation(prompt: str, get_secret("AZURE_AD_TOKEN") ) - # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) - pass + model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version) elif custom_llm_provider == "openai": model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 1e4155062..2f561e881 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -727,7 +727,7 @@ def test_completion_azure(): except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_azure() +test_completion_azure() def test_azure_openai_ad_token(): # this tests if the azure ad token is set in the request header diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index a265c0f65..d177ec81d 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -4,7 +4,8 @@ import sys, os import traceback from dotenv import load_dotenv - +import logging +logging.basicConfig(level=logging.DEBUG) load_dotenv() import os @@ -18,14 +19,22 @@ def test_image_generation_openai(): litellm.set_verbose = True response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") print(f"response: {response}") + assert len(response.data) > 0 # test_image_generation_openai() -# def test_image_generation_azure(): -# response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure") -# print(f"response: {response}") +def test_image_generation_azure(): + response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview") + print(f"response: {response}") + assert 
len(response.data) > 0 # test_image_generation_azure() +def test_image_generation_azure_dall_e_3(): + litellm.set_verbose = True + response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test", api_version="2023-12-01-preview", api_base=os.getenv("AZURE_SWEDEN_API_BASE"), api_key=os.getenv("AZURE_SWEDEN_API_KEY")) + print(f"response: {response}") + assert len(response.data) > 0 +# test_image_generation_azure_dall_e_3() # @pytest.mark.asyncio # async def test_async_image_generation_openai(): # response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") diff --git a/litellm/utils.py b/litellm/utils.py index c449a239e..b46ba5a9b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1613,6 +1613,7 @@ def client(original_function): try: model = args[0] if len(args) > 0 else kwargs["model"] except: + model = None call_type = original_function.__name__ if call_type != CallTypes.image_generation.value: raise ValueError("model param not passed in.") From f59b9436bedb8fb16d7551df99b645c80c2cc218 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 16:58:15 +0530 Subject: [PATCH 31/53] feat(main.py): add async image generation support --- litellm/tests/test_image_generation.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index d177ec81d..06441e3f4 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -8,7 +8,7 @@ import logging logging.basicConfig(level=logging.DEBUG) load_dotenv() import os - +import asyncio sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path @@ -35,12 +35,15 @@ def test_image_generation_azure_dall_e_3(): print(f"response: {response}") assert len(response.data) > 0 # test_image_generation_azure_dall_e_3() -# @pytest.mark.asyncio -# async def test_async_image_generation_openai(): -# response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") -# print(f"response: {response}") +@pytest.mark.asyncio +async def test_async_image_generation_openai(): + response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3") + print(f"response: {response}") + assert len(response.data) > 0 -# @pytest.mark.asyncio -# async def test_async_image_generation_azure(): -# response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3") -# print(f"response: {response}") \ No newline at end of file +# asyncio.run(test_async_image_generation_openai()) + +@pytest.mark.asyncio +async def test_async_image_generation_azure(): + response = await litellm.aimage_generation(prompt="A cute baby sea otter", model="azure/dall-e-3-test") + print(f"response: {response}") \ No newline at end of file From f355e03515fea52decd3b6d8d61f94aedd68ccda Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 16:58:40 +0530 Subject: [PATCH 32/53] feat(main.py): add async image generation support --- litellm/llms/azure.py | 42 ++++++++++++++++++++++++++++++++--- litellm/llms/openai.py | 37 +++++++++++++++++++++++++++++++ litellm/main.py | 50 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 026f06fb8..208c02678 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -464,6 +464,42 @@ class AzureChatCompletion(BaseLLM): import traceback raise 
AzureOpenAIError(status_code=500, message=traceback.format_exc()) + async def aimage_generation( + self, + data: dict, + model_response: ModelResponse, + azure_client_params: dict, + api_key: str, + input: list, + client=None, + logging_obj=None + ): + response = None + try: + if client is None: + openai_aclient = AsyncAzureOpenAI(**azure_client_params) + else: + openai_aclient = client + response = await openai_aclient.images.generate(**data) + stringified_response = response.model_dump_json() + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + def image_generation(self, prompt: str, timeout: float, @@ -492,9 +528,9 @@ class AzureChatCompletion(BaseLLM): if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") - # if aembedding == True: - # response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore - # return response + if aimg_generation == True: + response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore + return response if client is None: client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 0731bd509..e6b535295 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -445,6 +445,43 @@ class OpenAIChatCompletion(BaseLLM): import traceback raise OpenAIError(status_code=500, message=traceback.format_exc()) + async def aimage_generation( + self, + prompt: str, + data: dict, + model_response: ModelResponse, + timeout: float, + api_key: Optional[str]=None, + api_base: Optional[str]=None, + client=None, + max_retries=None, + logging_obj=None + ): + response = None + try: + if client is None: + openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + else: + openai_aclient = client + response = await openai_aclient.images.generate(**data) # type: ignore + stringified_response = response.model_dump_json() + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + original_response=str(e), + ) + raise e + def image_generation(self, model: Optional[str], prompt: str, diff --git a/litellm/main.py b/litellm/main.py index 318ba8ffb..b2ed72f7f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2230,6 +2230,49 @@ def moderation(input: str, api_key: Optional[str]=None): return response ##### Image 
Generation ####################### +@client +async def aimage_generation(*args, **kwargs): + """ + Asynchronously calls the `image_generation` function with the given arguments and keyword arguments. + + Parameters: + - `args` (tuple): Positional arguments to be passed to the `embedding` function. + - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function. + + Returns: + - `response` (Any): The response returned by the `embedding` function. + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO Image Generation ### + kwargs["aimg_generation"] = True + custom_llm_provider = None + try: + # Use a partial function to pass your keyword arguments + func = partial(image_generation, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + return response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, + ) + @client def image_generation(prompt: str, model: Optional[str]=None, @@ -2251,6 +2294,7 @@ def image_generation(prompt: str, Currently supports just Azure + OpenAI. 
""" + aimg_generation = kwargs.get("aimg_generation", False) litellm_call_id = kwargs.get("litellm_call_id", None) logger_fn = kwargs.get("logger_fn", None) proxy_server_request = kwargs.get('proxy_server_request', None) @@ -2264,7 +2308,7 @@ def image_generation(prompt: str, model = "dall-e-2" custom_llm_provider = "openai" # default to dall-e-2 on openai openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"] - litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] + litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] default_params = openai_params + litellm_params non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider optional_params = get_optional_params_image_gen(n=n, @@ -2307,9 +2351,9 @@ def image_generation(prompt: str, get_secret("AZURE_AD_TOKEN") ) - model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version) + model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation) elif custom_llm_provider == "openai": - model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response) + model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation) return model_response From 4040f60febea4f84321c1b024ecea722c8c78df9 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 17:24:20 +0530 Subject: [PATCH 33/53] feat(router.py): support async image generation on router --- litellm/router.py | 57 ++++++++++++++++++++++++++++++++---- 
litellm/tests/test_router.py | 35 ++++++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 0276f5a44..1e2a32263 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7,7 +7,7 @@ # # Thank you ! We ❤️ you! - Krrish & Ishaan -import copy +import copy, httpx from datetime import datetime from typing import Dict, List, Optional, Union, Literal, Any import random, threading, time, traceback, uuid @@ -18,6 +18,7 @@ import inspect, concurrent from openai import AsyncOpenAI from collections import defaultdict from litellm.router_strategy.least_busy import LeastBusyLoggingHandler +from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport import copy class Router: """ @@ -166,7 +167,7 @@ class Router: self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n") - ### COMPLETION + EMBEDDING FUNCTIONS + ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS def completion(self, model: str, @@ -260,6 +261,50 @@ class Router: self.fail_calls[model_name] +=1 raise e + async def aimage_generation(self, + prompt: str, + model: str, + **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._image_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = await self.async_function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + async def _image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") + deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None)) + kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if k not in kwargs: # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") + self.total_calls[model_name] +=1 + response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) + self.success_calls[model_name] +=1 + return response + except Exception as e: + if model_name is not None: + self.fail_calls[model_name] +=1 + raise e + def text_completion(self, model: str, prompt: str, @@ -1009,14 +1054,16 @@ class Router: azure_endpoint=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, + http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, azure_endpoint=api_base, api_version=api_version, timeout=timeout, - max_retries=max_retries + max_retries=max_retries, + http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) # streaming clients should have diff timeouts model["stream_async_client"] = openai.AsyncAzureOpenAI( @@ -1024,7 +1071,7 @@ class Router: azure_endpoint=api_base, api_version=api_version, timeout=stream_timeout, - 
max_retries=max_retries + max_retries=max_retries, ) model["stream_client"] = openai.AzureOpenAI( diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 403c8dc2a..435be3ed5 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -423,6 +423,41 @@ def test_function_calling_on_router(): # test_function_calling_on_router() +### IMAGE GENERATION +async def test_aimg_gen_on_router(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY") + } + } + ] + router = Router(model_list=model_list) + response = await router.aimage_generation( + model="dall-e-3", + prompt="A cute baby sea otter" + ) + print(response) + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +asyncio.run(test_aimg_gen_on_router()) +### + def test_aembedding_on_router(): litellm.set_verbose = True try: From 350389f50108fafba23f1c9fd868700bfe1f4e54 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 17:48:33 +0530 Subject: [PATCH 34/53] fix(utils.py): add support for anyscale function calling --- litellm/tests/test_completion.py | 56 ++++++++++++++++---------------- litellm/utils.py | 6 +++- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 2f561e881..8cecd270a 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider(): # test_completion_hf_model_no_provider() -# def test_completion_openai_azure_with_functions(): -# function1 = [ -# { -# "name": "get_current_weather", -# "description": "Get the current weather in a given location", -# "parameters": { -# "type": "object", -# "properties": { -# "location": { -# "type": "string", -# "description": "The city and state, e.g. San Francisco, CA", -# }, -# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, -# }, -# "required": ["location"], -# }, -# } -# ] -# try: -# messages = [{"role": "user", "content": "What is the weather like in Boston?"}] -# response = completion( -# model="azure/chatgpt-functioncalling", messages=messages, functions=function1 -# ) -# # Add any assertions here to check the response -# print(response) -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") -# test_completion_openai_azure_with_functions() +def test_completion_anyscale_with_functions(): + function1 = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + try: + messages = [{"role": "user", "content": "What is the weather like in Boston?"}] + response = completion( + model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1 + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") +test_completion_anyscale_with_functions() def test_completion_azure_key_completion_arg(): # this tests if we can pass api_key to completion, when it's not in the env diff --git a/litellm/utils.py b/litellm/utils.py index b46ba5a9b..53c9dce17 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2398,6 +2398,8 @@ def get_optional_params( # use the openai defaults optional_params["format"] = "json" litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) + elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral + pass elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) else: @@ -2825,7 +2827,9 @@ def get_optional_params( # use the openai defaults if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty elif custom_llm_provider == "anyscale": - supported_params = ["temperature", "top_p", "stream", "max_tokens"] + supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"] + if model == "mistralai/Mistral-7B-Instruct-v0.1": + supported_params += ["functions", "function_call", "tools", "tool_choice"] _check_valid_arg(supported_params=supported_params) optional_params = non_default_params if temperature is not None: From 50af89e853dbf13bff37891375eb0f64e629173f Mon Sep 17 00:00:00 2001 From: Reuben Thomas-Davis Date: Wed, 20 Dec 2023 12:40:43 +0000 Subject: [PATCH 35/53] :green_heart: docker build and push on release when a github release is published a docker image is pushed to ghcr avoiding manual workflow dispatch method (but still making it available as a fallback) --- .github/workflows/ghcr_deploy.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ghcr_deploy.yml b/.github/workflows/ghcr_deploy.yml index fc345670c..7e69e8d39 100644 --- a/.github/workflows/ghcr_deploy.yml +++ b/.github/workflows/ghcr_deploy.yml @@ -1,10 +1,12 @@ # -name: Build & Publich to GHCR +name: Build & Publish to GHCR on: workflow_dispatch: inputs: tag: description: "The tag version you want to build" + release: + types: [published] # Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds. env: @@ -19,7 +21,7 @@ jobs: permissions: contents: read packages: write - # + # steps: - name: Checkout repository uses: actions/checkout@v4 @@ -44,5 +46,5 @@ jobs: with: context: . 
push: true - tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag }} # Add the input tag to the image tags + tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name }} # if a tag is provided, use that, otherwise use the release tag labels: ${{ steps.meta.outputs.labels }} From fe4427907d7cf2eaf116c6e1e7a14a41e7616e66 Mon Sep 17 00:00:00 2001 From: Reuben Thomas-Davis Date: Wed, 20 Dec 2023 12:41:23 +0000 Subject: [PATCH 36/53] :wastebasket: remove unused docker workflow for clarity --- .github/workflows/docker.yml | 66 ------------------------------------ 1 file changed, 66 deletions(-) delete mode 100644 .github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml deleted file mode 100644 index cf6c40b31..000000000 --- a/.github/workflows/docker.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Build Docker Images -on: - workflow_dispatch: - inputs: - tag: - description: "The tag version you want to build" -jobs: - build: - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - env: - REPO_NAME: ${{ github.repository }} - steps: - - name: Convert repo name to lowercase - run: echo "REPO_NAME=$(echo "$REPO_NAME" | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV - - name: Checkout - uses: actions/checkout@v4 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GHCR_TOKEN }} - logout: false - - name: Extract metadata (tags, labels) for Docker - id: meta - uses: docker/metadata-action@v5 - with: - images: ghcr.io/berriai/litellm - - name: Get tag to build - id: tag - run: | - echo "latest=ghcr.io/${{ env.REPO_NAME }}:latest" >> $GITHUB_OUTPUT - if [[ -z "${{ github.event.inputs.tag }}" ]]; then - echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.ref_name }}" >> $GITHUB_OUTPUT - else - echo "versioned=ghcr.io/${{ env.REPO_NAME }}:${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT - fi - - name: Debug Info - run: | - echo "GHCR_TOKEN=${{ secrets.GHCR_TOKEN }}" - echo "REPO_NAME=${{ env.REPO_NAME }}" - echo "ACTOR=${{ github.actor }}" - - name: Build and push container image to registry - uses: docker/build-push-action@v2 - with: - push: true - tags: ghcr.io/${{ env.REPO_NAME }}:${{ github.sha }} - file: ./Dockerfile - - name: Build and release Docker images - uses: docker/build-push-action@v5 - with: - context: . - platforms: linux/amd64 - tags: | - ${{ steps.tag.outputs.latest }} - ${{ steps.tag.outputs.versioned }} - labels: ${{ steps.meta.outputs.labels }} - push: true - From 04bbd0649f37f3991b5bd570a3cf5788af599287 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Dec 2023 19:10:59 +0530 Subject: [PATCH 37/53] fix(router.py): only do sync image gen fallbacks for now The customhttptransport we use for dall-e-2 only works for sync httpx calls, not async. 
Will need to spend some time writing the async version n --- litellm/llms/custom_httpx/azure_dall_e_2.py | 4 +- litellm/router.py | 49 +++++++++++++++++++-- litellm/tests/test_proxy_custom_logger.py | 2 - litellm/tests/test_router.py | 36 ++++++++++++++- 4 files changed, 82 insertions(+), 9 deletions(-) diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index c5263bd49..cda84b156 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,6 +1,4 @@ -import time -import json -import httpx +import time, json, httpx, asyncio class CustomHTTPTransport(httpx.HTTPTransport): """ diff --git a/litellm/router.py b/litellm/router.py index 1e2a32263..ddfd6b87c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -261,6 +261,50 @@ class Router: self.fail_calls[model_name] +=1 raise e + def image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + kwargs["model"] = model + kwargs["prompt"] = prompt + kwargs["original_function"] = self._image_generation + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + + return response + except Exception as e: + raise e + + def _image_generation(self, + prompt: str, + model: str, + **kwargs): + try: + self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}") + deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None)) + kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if k not in kwargs: # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async") + self.total_calls[model_name] +=1 + response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs}) + self.success_calls[model_name] +=1 + return response + except Exception as e: + if model_name is not None: + self.fail_calls[model_name] +=1 + raise e + async def aimage_generation(self, prompt: str, model: str, @@ -268,7 +312,7 @@ class Router: try: kwargs["model"] = model kwargs["prompt"] = prompt - kwargs["original_function"] = self._image_generation + kwargs["original_function"] = self._aimage_generation kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) timeout = kwargs.get("request_timeout", self.timeout) kwargs.setdefault("metadata", {}).update({"model_group": model}) @@ -278,7 +322,7 @@ class Router: except Exception as e: raise e - async def _image_generation(self, + async def _aimage_generation(self, prompt: str, model: str, **kwargs): @@ -1055,7 +1099,6 @@ class Router: api_version=api_version, timeout=timeout, max_retries=max_retries, - http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, diff --git a/litellm/tests/test_proxy_custom_logger.py b/litellm/tests/test_proxy_custom_logger.py index 6ddc9caac..0a3097af9 100644 --- 
a/litellm/tests/test_proxy_custom_logger.py +++ b/litellm/tests/test_proxy_custom_logger.py @@ -99,8 +99,6 @@ def test_embedding(client): def test_chat_completion(client): try: # Your test data - - print("initialized proxy") litellm.set_verbose=False from litellm.proxy.utils import get_instance_fn my_custom_logger = get_instance_fn( diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 435be3ed5..d7f929f25 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -455,7 +455,41 @@ async def test_aimg_gen_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -asyncio.run(test_aimg_gen_on_router()) +# asyncio.run(test_aimg_gen_on_router()) + +def test_img_gen_on_router(): + litellm.set_verbose = True + try: + model_list = [ + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "dall-e-3", + }, + }, + { + "model_name": "dall-e-3", + "litellm_params": { + "model": "azure/dall-e-3-test", + "api_version": "2023-12-01-preview", + "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), + "api_key": os.getenv("AZURE_SWEDEN_API_KEY") + } + } + ] + router = Router(model_list=model_list) + response = router.image_generation( + model="dall-e-3", + prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 + router.reset() + except Exception as e: + traceback.print_exc() + pytest.fail(f"Error occurred: {e}") + +test_img_gen_on_router() ### def test_aembedding_on_router(): From 482b3b5bc3dc6acfd38f741de7ccf338216bad6c Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Wed, 20 Dec 2023 13:12:50 -0500 Subject: [PATCH 38/53] Add a default for safety settings in vertex AI --- litellm/llms/vertex_ai.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index f55575227..7cfc91701 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -195,6 +195,7 @@ def completion( optional_params[k] = v ## Process safety settings into format expected by vertex AI + safety_settings = None if "safety_settings" in optional_params: safety_settings = optional_params.pop("safety_settings") if not isinstance(safety_settings, list): From 6795f0447ac87422844fc83893cc72421c4a7237 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 06:59:13 +0530 Subject: [PATCH 39/53] fix(utils.py): fix non_default_param pop error for ollama --- litellm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 53c9dce17..f68fd49b0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2397,7 +2397,11 @@ def get_optional_params( # use the openai defaults # ollama actually supports json output optional_params["format"] = "json" litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt - optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) + if "tools" in non_default_params: + optional_params["functions_unsupported_model"] = non_default_params.pop("tools") + non_default_params.pop("tool_choice", None) # causes ollama requests to hang + elif "functions" in non_default_params: + optional_params["functions_unsupported_model"] = non_default_params.pop("functions") elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral pass elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead From 
b701a356cc3482c90217000ca056e09ea656a0f4 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 21 Dec 2023 07:22:09 +0530 Subject: [PATCH 40/53] (fix) vertex ai auth file --- litellm/tests/vertex_key.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/litellm/tests/vertex_key.json b/litellm/tests/vertex_key.json index 728fcdb98..bd319ac94 100644 --- a/litellm/tests/vertex_key.json +++ b/litellm/tests/vertex_key.json @@ -1,13 +1,13 @@ { "type": "service_account", - "project_id": "hardy-device-386718", + "project_id": "reliablekeys", "private_key_id": "", "private_key": "", - "client_email": "litellm-vertexai-ci-cd@hardy-device-386718.iam.gserviceaccount.com", - "client_id": "110281020501213430254", + "client_email": "73470430121-compute@developer.gserviceaccount.com", + "client_id": "108560959659377334173", "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://oauth2.googleapis.com/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/litellm-vertexai-ci-cd%40hardy-device-386718.iam.gserviceaccount.com", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com", "universe_domain": "googleapis.com" } \ No newline at end of file From 3e5cfee1f4a932b9cfcee74ad6d0d741229f9b4e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 21 Dec 2023 07:25:22 +0530 Subject: [PATCH 41/53] (ci/cd) run again --- litellm/tests/test_amazing_vertex_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 6506f0a41..344c88c99 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -20,7 +20,7 @@ import tempfile litellm.num_retries = 3 litellm.cache = None user_message = "Write a short poem about the sky" -messages = [{"content": user_message, "role": "user"}] +messages = [{"content": user_message, "role": "user"}] def load_vertex_ai_credentials(): From fbab7371dc95ca3d792f1770aedf31a30bb1460e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 21 Dec 2023 07:36:32 +0530 Subject: [PATCH 42/53] (docs) proxy - virtual keys --- docs/my-website/docs/proxy/virtual_keys.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 181ccf648..f0f0c8def 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -40,7 +40,8 @@ litellm --config /path/to/config.yaml ```shell curl 'http://0.0.0.0:8000/key/generate' \ --header 'Authorization: Bearer sk-1234' \ ---data '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}' +--header 'Content-Type: application/json' \ +--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}' ``` - `models`: *list or null (optional)* - Specify the models a token has access too. If null, then token has access to all models on server. 
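As a quick illustration of the `/key/generate` call documented in the virtual-keys docs above, here is a minimal Python sketch of the same request made with the `requests` library instead of curl. It assumes a proxy running locally on port 8000 with `sk-1234` as the master key (the values used in the docs example); the `metadata` field is the optional dict described in the following patch.

```python
import requests

# Generate a virtual key from a locally running LiteLLM proxy.
# Assumes `sk-1234` is the configured master key, as in the docs example.
resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={
        "Authorization": "Bearer sk-1234",  # proxy master key
        "Content-Type": "application/json",
    },
    json={
        "models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],  # models this key may call
        "duration": "20m",                                  # key validity window
        "metadata": {"user": "ishaan@berri.ai", "team": "core-infra"},  # optional metadata attached to the key
    },
)
resp.raise_for_status()
print(resp.json())  # payload returned by the proxy, including the newly generated key
```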
From f6407aaf7414b4eb7b1bfd29241517306f7e289e Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 21 Dec 2023 07:43:53 +0530 Subject: [PATCH 43/53] (docs) add metadata keys/generate --- docs/my-website/docs/proxy/virtual_keys.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index f0f0c8def..91702eddc 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -41,13 +41,15 @@ litellm --config /path/to/config.yaml curl 'http://0.0.0.0:8000/key/generate' \ --header 'Authorization: Bearer sk-1234' \ --header 'Content-Type: application/json' \ ---data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}' +--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m","metadata": {"user": "ishaan@berri.ai", "team": "core-infra"}}' ``` - `models`: *list or null (optional)* - Specify the models a token has access too. If null, then token has access to all models on server. - `duration`: *str or null (optional)* Specify the length of time the token is valid for. If null, default is set to 1 hour. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +- `metadata`: *dict or null (optional)* Pass metadata for the created token. If null defaults to {} + Expected response: ```python From 97f64750352a3acab0f7fe548e03901d25335895 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 07:53:56 +0530 Subject: [PATCH 44/53] docs(health.md): add docs on health checks for embedding models --- docs/my-website/docs/proxy/health.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/health.md b/docs/my-website/docs/proxy/health.md index 5dffd7100..d724bdaca 100644 --- a/docs/my-website/docs/proxy/health.md +++ b/docs/my-website/docs/proxy/health.md @@ -59,4 +59,21 @@ $ litellm /path/to/config.yaml 3. 
Query health endpoint: ``` curl --location 'http://0.0.0.0:8000/health' -``` \ No newline at end of file +``` + +## Embedding Models + +We need some way to know if the model is an embedding model when running checks, if you have this in your config, specifying mode it makes an embedding health check + +```yaml +model_list: + - model_name: azure-embedding-model + litellm_params: + model: azure/azure-embedding-model + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: "2023-07-01-preview" + model_info: + mode: embedding # 👈 ADD THIS +``` + From 87fca1808a90300a80511074ba8c6a3a8c0d65c5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 12:14:34 +0530 Subject: [PATCH 45/53] test(test_amazing_vertex_completion.py): fix project name --- litellm/tests/test_amazing_vertex_completion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 344c88c99..910bd8f6b 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -91,7 +91,7 @@ def test_vertex_ai(): load_vertex_ai_credentials() test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models litellm.set_verbose=False - litellm.vertex_project = "hardy-device-386718" + litellm.vertex_project = "reliablekeys" test_models = random.sample(test_models, 1) test_models += litellm.vertex_language_models # always test gemini-pro @@ -113,7 +113,7 @@ def test_vertex_ai(): def test_vertex_ai_stream(): load_vertex_ai_credentials() litellm.set_verbose=False - litellm.vertex_project = "hardy-device-386718" + litellm.vertex_project = "reliablekeys" import random test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models From 812f9ca1b348917c452a3ff8a0e35cf5355353fb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 12:23:07 +0530 Subject: [PATCH 46/53] fix(azure.py): correctly raise async exceptions --- litellm/llms/azure.py | 5 ++++- litellm/tests/test_proxy_exception_mapping.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 208c02678..b269afec7 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -262,7 +262,10 @@ class AzureChatCompletion(BaseLLM): exception_mapping_worked = True raise e except Exception as e: - raise AzureOpenAIError(status_code=500, message=str(e)) + if hasattr(e, "status_code"): + raise e + else: + raise AzureOpenAIError(status_code=500, message=str(e)) def streaming(self, logging_obj, diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index 5dcb782c4..c5f99f28c 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -68,6 +68,7 @@ def test_chat_completion_exception_azure(client): # make an openai client to call _make_status_error_from_response openai_client = openai.OpenAI(api_key="anything") openai_exception = openai_client._make_status_error_from_response(response=response) + print(openai_exception) assert isinstance(openai_exception, openai.AuthenticationError) except Exception as e: From 81078c400439b650e957f09732feae1c3032a4fe Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 13:03:14 +0530 Subject: [PATCH 47/53] 
fix(proxy/utils.py): jsonify object before db writes --- litellm/proxy/proxy_server.py | 1 - litellm/proxy/utils.py | 14 +++++++++++--- litellm/tests/langfuse.log | 4 ---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 10b40321e..296d8e576 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1169,7 +1169,6 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest): raise Exception("Not connected to DB!") non_default_values = {k: v for k, v in data_json.items() if v is not None} - print(f"non_default_values: {non_default_values}") response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key}) return {"key": key, **non_default_values} # update based on remaining passed in values diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 938508bcc..52a2fb6aa 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,5 +1,5 @@ from typing import Optional, List, Any, Literal -import os, subprocess, hashlib, importlib, asyncio, copy +import os, subprocess, hashlib, importlib, asyncio, copy, json import litellm, backoff from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache @@ -147,6 +147,14 @@ class PrismaClient: return hashed_token + def jsonify_object(self, data: dict) -> dict: + db_data = copy.deepcopy(data) + + for k, v in db_data.items(): + if isinstance(v, dict): + db_data[k] = json.dumps(v) + return db_data + @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff @@ -193,7 +201,7 @@ class PrismaClient: try: token = data["token"] hashed_token = self.hash_token(token=token) - db_data = copy.deepcopy(data) + db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore @@ -228,7 +236,7 @@ class PrismaClient: if token.startswith("sk-"): token = self.hash_token(token=token) - db_data = copy.deepcopy(data) + db_data = self.jsonify_object(data=data) db_data["token"] = token response = await self.db.litellm_verificationtoken.update( where={ diff --git a/litellm/tests/langfuse.log b/litellm/tests/langfuse.log index 58cdb5267..e69de29bb 100644 --- a/litellm/tests/langfuse.log +++ b/litellm/tests/langfuse.log @@ -1,4 +0,0 @@ -uploading batch of 2 items -successfully uploaded batch of 2 items -uploading batch of 2 items -successfully uploaded batch of 2 items From a4aa645cf64abf8724f87775fa8be946fc20d25a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 13:23:00 +0530 Subject: [PATCH 48/53] =?UTF-8?q?bump:=20version=201.15.1=20=E2=86=92=201.?= =?UTF-8?q?15.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91c1236ef..070d6b94c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.15.1" +version = "1.15.2" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.15.1" +version = "1.15.2" version_files = [ "pyproject.toml:^version" ] From be68796ebac1a60c725cfdb9e43969e67ec29983 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 
2023 14:38:44 +0530 Subject: [PATCH 49/53] fix(router.py): add support for async image generation endpoints --- litellm/llms/azure.py | 25 +++++++-- litellm/llms/custom_httpx/azure_dall_e_2.py | 56 +++++++++++++++++++++ litellm/main.py | 4 +- litellm/router.py | 4 +- litellm/tests/test_router.py | 25 +++++++-- litellm/utils.py | 8 ++- 6 files changed, 109 insertions(+), 13 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index b269afec7..2e75f7b40 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -6,7 +6,7 @@ from typing import Callable, Optional from litellm import OpenAIConfig import litellm, json import httpx -from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport +from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from openai import AzureOpenAI, AsyncAzureOpenAI class AzureOpenAIError(Exception): @@ -480,7 +480,8 @@ class AzureChatCompletion(BaseLLM): response = None try: if client is None: - openai_aclient = AsyncAzureOpenAI(**azure_client_params) + client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) + openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params) else: openai_aclient = client response = await openai_aclient.images.generate(**data) @@ -492,7 +493,7 @@ class AzureChatCompletion(BaseLLM): additional_args={"complete_input_dict": data}, original_response=stringified_response, ) - return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") + return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") except Exception as e: ## LOGGING logging_obj.post_call( @@ -511,6 +512,7 @@ class AzureChatCompletion(BaseLLM): api_base: Optional[str] = None, api_version: Optional[str] = None, model_response: Optional[litellm.utils.ImageResponse] = None, + azure_ad_token: Optional[str]=None, logging_obj=None, optional_params=None, client=None, @@ -531,13 +533,26 @@ class AzureChatCompletion(BaseLLM): if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + if aimg_generation == True: - response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore + response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore return response if client is None: client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) - azure_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_base, http_client=client_session, timeout=timeout, max_retries=max_retries, api_version=api_version) # type: ignore + azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore else: azure_client = client diff --git 
a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index cda84b156..3bc50dda7 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -1,5 +1,61 @@ import time, json, httpx, asyncio +class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): + """ + Async implementation of custom http transport + """ + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + if "images/generations" in request.url.path and request.url.params[ + "api-version" + ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ]: + request.url = request.url.copy_with(path="/openai/images/generations:submit") + response = await super().handle_async_request(request) + operation_location_url = response.headers["operation-location"] + request.url = httpx.URL(operation_location_url) + request.method = "GET" + response = await super().handle_async_request(request) + await response.aread() + + timeout_secs: int = 120 + start_time = time.time() + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} + return httpx.Response( + status_code=400, + headers=response.headers, + content=json.dumps(timeout).encode("utf-8"), + request=request, + ) + + time.sleep(int(response.headers.get("retry-after")) or 10) + response = await super().handle_async_request(request) + await response.aread() + + if response.json()["status"] == "failed": + error_data = response.json() + return httpx.Response( + status_code=400, + headers=response.headers, + content=json.dumps(error_data).encode("utf-8"), + request=request, + ) + + result = response.json()["result"] + return httpx.Response( + status_code=200, + headers=response.headers, + content=json.dumps(result).encode("utf-8"), + request=request, + ) + return await super().handle_async_request(request) + class CustomHTTPTransport(httpx.HTTPTransport): """ This class was written as a workaround to support dall-e-2 on openai > v1.x diff --git a/litellm/main.py b/litellm/main.py index b2ed72f7f..0e7752e16 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2351,9 +2351,9 @@ def image_generation(prompt: str, get_secret("AZURE_AD_TOKEN") ) - model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation) + model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimg_generation) elif custom_llm_provider == "openai": - model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation) + model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, 
logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimg_generation) return model_response diff --git a/litellm/router.py b/litellm/router.py index ddfd6b87c..4ee067e2f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -18,7 +18,7 @@ import inspect, concurrent from openai import AsyncOpenAI from collections import defaultdict from litellm.router_strategy.least_busy import LeastBusyLoggingHandler -from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport +from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport import copy class Router: """ @@ -525,7 +525,6 @@ class Router: async def async_function_with_retries(self, *args, **kwargs): self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}") - backoff_factor = 1 original_function = kwargs.pop("original_function") fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks) @@ -1099,6 +1098,7 @@ class Router: api_version=api_version, timeout=timeout, max_retries=max_retries, + http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) # type: ignore ) model["client"] = openai.AzureOpenAI( api_key=api_key, diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index d7f929f25..b52db394f 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -424,6 +424,7 @@ def test_function_calling_on_router(): # test_function_calling_on_router() ### IMAGE GENERATION +@pytest.mark.asyncio async def test_aimg_gen_on_router(): litellm.set_verbose = True try: @@ -442,14 +443,32 @@ async def test_aimg_gen_on_router(): "api_base": os.getenv("AZURE_SWEDEN_API_BASE"), "api_key": os.getenv("AZURE_SWEDEN_API_KEY") } + }, + { + "model_name": "dall-e-2", + "litellm_params": { + "model": "azure/", + "api_version": "2023-06-01-preview", + "api_base": os.getenv("AZURE_API_BASE"), + "api_key": os.getenv("AZURE_API_KEY") + } } ] router = Router(model_list=model_list) + # response = await router.aimage_generation( + # model="dall-e-3", + # prompt="A cute baby sea otter" + # ) + # print(response) + # assert len(response.data) > 0 + response = await router.aimage_generation( - model="dall-e-3", + model="dall-e-2", prompt="A cute baby sea otter" ) print(response) + assert len(response.data) > 0 + router.reset() except Exception as e: traceback.print_exc() @@ -489,7 +508,7 @@ def test_img_gen_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -test_img_gen_on_router() +# test_img_gen_on_router() ### def test_aembedding_on_router(): @@ -625,7 +644,7 @@ async def test_mistral_on_router(): ] ) print(response) -asyncio.run(test_mistral_on_router()) +# asyncio.run(test_mistral_on_router()) def test_openai_completion_on_router(): # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream diff --git a/litellm/utils.py b/litellm/utils.py index f68fd49b0..3a48958fc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -551,6 +551,8 @@ class ImageResponse(OpenAIObject): data: Optional[list] = None + usage: Optional[dict] = None + def __init__(self, created=None, data=None, response_ms=None): if response_ms: _response_ms = response_ms @@ -565,8 +567,10 @@ class ImageResponse(OpenAIObject): created = created else: created = None - + super().__init__(data=data, created=created) + self.usage = 
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + def __contains__(self, key): # Define custom behavior for the 'in' operator @@ -1668,6 +1672,8 @@ def client(original_function): return result elif "aembedding" in kwargs and kwargs["aembedding"] == True: return result + elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True: + return result ### POST-CALL RULES ### post_call_processing(original_response=result, model=model or None) From 8101ad6801d21f2779ae414cdded80db75065303 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 14:39:29 +0530 Subject: [PATCH 50/53] =?UTF-8?q?bump:=20version=201.15.2=20=E2=86=92=201.?= =?UTF-8?q?15.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 070d6b94c..62738f3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.15.2" +version = "1.15.3" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.15.2" +version = "1.15.3" version_files = [ "pyproject.toml:^version" ] From 14115d0d607ea1b6503f5579fb9abccc06d5f59b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 15:39:09 +0530 Subject: [PATCH 51/53] feat(proxy_server.py): add new images/generation endpoint --- litellm/proxy/proxy_server.py | 66 +++++++++++++++++++ .../test_configs/test_config_no_auth.yaml | 17 ++++- litellm/tests/test_proxy_server.py | 25 ++++++- litellm/tests/test_router.py | 12 ++-- 4 files changed, 112 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 296d8e576..8a32f1b4f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1127,6 +1127,72 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen detail=error_msg ) + +@router.post("/v1/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"]) +@router.post("/images/generations", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse, tags=["image generation"]) +async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): + global proxy_logging_obj + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + body = await request.body() + data = orjson.loads(body) + + # Include original request and headers in the data + data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data) # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + data["model"] = ( + general_settings.get("image_generation_model", None) # server default + or user_model # model name passed via cli args + or data["model"] # default passed in http request + ) + if user_model: + data["model"] = user_model + if "metadata" in data: + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["headers"] = dict(request.headers) + else: + data["metadata"] = 
{"user_api_key": user_api_key_dict.api_key} + data["metadata"]["headers"] = dict(request.headers) + router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] + + ### CALL HOOKS ### - modify incoming data / reject request before calling the model + data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings") + ## ROUTE TO CORRECT ENDPOINT ## + if llm_router is not None and data["model"] in router_model_names: # model in router model list + response = await llm_router.aimage_generation(**data) + elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router + response = await llm_router.aimage_generation(**data, specific_deployment = True) + elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias + response = await llm_router.aimage_generation(**data) # ensure this goes the llm_router, router will do the correct alias mapping + else: + response = await litellm.aimage_generation(**data) + background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL + + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e) + traceback.print_exc() + if isinstance(e, HTTPException): + raise e + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + try: + status = e.status_code # type: ignore + except: + status = 500 + raise HTTPException( + status_code=status, + detail=error_msg + ) #### KEY MANAGEMENT #### @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse) diff --git a/litellm/tests/test_configs/test_config_no_auth.yaml b/litellm/tests/test_configs/test_config_no_auth.yaml index 2fd9ef203..76e7a294b 100644 --- a/litellm/tests/test_configs/test_config_no_auth.yaml +++ b/litellm/tests/test_configs/test_config_no_auth.yaml @@ -78,4 +78,19 @@ model_list: model: "bedrock/amazon.titan-embed-text-v1" - model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)" litellm_params: - model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" \ No newline at end of file + model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" +- model_name: dall-e-3 + litellm_params: + model: dall-e-3 +- model_name: dall-e-3 + litellm_params: + model: "azure/dall-e-3-test" + api_version: "2023-12-01-preview" + api_base: "os.environ/AZURE_SWEDEN_API_BASE" + api_key: "os.environ/AZURE_SWEDEN_API_KEY" +- model_name: dall-e-2 + litellm_params: + model: "azure/" + api_version: "2023-06-01-preview" + api_base: "os.environ/AZURE_API_BASE" + api_key: "os.environ/AZURE_API_KEY" \ No newline at end of file diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 5e9854f43..b71f5b890 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -101,7 +101,7 @@ def test_chat_completion_azure(client_no_auth): # Run the test # test_chat_completion_azure() - +### EMBEDDING def test_embedding(client_no_auth): global headers from litellm.proxy.proxy_server import user_custom_auth @@ -161,7 +161,30 @@ def test_sagemaker_embedding(client_no_auth): # Run the test # test_embedding() +#### IMAGE GENERATION + +def test_img_gen(client_no_auth): + global headers + from 
litellm.proxy.proxy_server import user_custom_auth + try: + test_data = { + "model": "dall-e-3", + "prompt": "A cute baby sea otter", + "n": 1, + "size": "1024x1024" + } + + response = client_no_auth.post("/v1/images/generations", json=test_data) + + assert response.status_code == 200 + result = response.json() + print(len(result["data"][0]["url"])) + assert len(result["data"][0]["url"]) > 10 + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + +#### ADDITIONAL # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci") def test_add_new_model(client_no_auth): global headers diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index b52db394f..81440c257 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -455,12 +455,12 @@ async def test_aimg_gen_on_router(): } ] router = Router(model_list=model_list) - # response = await router.aimage_generation( - # model="dall-e-3", - # prompt="A cute baby sea otter" - # ) - # print(response) - # assert len(response.data) > 0 + response = await router.aimage_generation( + model="dall-e-3", + prompt="A cute baby sea otter" + ) + print(response) + assert len(response.data) > 0 response = await router.aimage_generation( model="dall-e-2", From 1a32228da536e4eb2fd0e926af783a51e6c90646 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 16:07:20 +0530 Subject: [PATCH 52/53] feat(proxy_server.py): support max budget on proxy --- litellm/proxy/_types.py | 30 +++++++++---------- litellm/proxy/hooks/max_budget_limiter.py | 35 +++++++++++++++++++++++ litellm/proxy/proxy_server.py | 14 +++++++-- litellm/proxy/schema.prisma | 1 + litellm/proxy/utils.py | 4 ++- 5 files changed, 66 insertions(+), 18 deletions(-) create mode 100644 litellm/proxy/hooks/max_budget_limiter.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 233c1b642..76d37bddf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -121,6 +121,7 @@ class GenerateKeyRequest(LiteLLMBase): user_id: Optional[str] = None max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} + max_budget: Optional[float] = None class UpdateKeyRequest(LiteLLMBase): key: str @@ -132,21 +133,7 @@ class UpdateKeyRequest(LiteLLMBase): user_id: Optional[str] = None max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} - -class GenerateKeyResponse(LiteLLMBase): - key: str - expires: datetime - user_id: str - - - - -class _DeleteKeyObject(LiteLLMBase): - key: str - -class DeleteKeyRequest(LiteLLMBase): - keys: List[_DeleteKeyObject] - + max_budget: Optional[float] = None class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth """ @@ -161,6 +148,19 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k max_parallel_requests: Optional[int] = None duration: str = "1h" metadata: dict = {} + max_budget: Optional[float] = None + +class GenerateKeyResponse(LiteLLMBase): + key: str + expires: datetime + user_id: str + +class _DeleteKeyObject(LiteLLMBase): + key: str + +class DeleteKeyRequest(LiteLLMBase): + keys: List[_DeleteKeyObject] + class ConfigGeneralSettings(LiteLLMBase): """ diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py new file mode 100644 index 000000000..b2ffbeea8 --- /dev/null +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -0,0 +1,35 @@ +from typing import Optional +import litellm +from litellm.caching import DualCache +from 
litellm.proxy._types import UserAPIKeyAuth +from litellm.integrations.custom_logger import CustomLogger +from fastapi import HTTPException + +class MaxBudgetLimiter(CustomLogger): + # Class variables or attributes + def __init__(self): + pass + + def print_verbose(self, print_statement): + if litellm.set_verbose is True: + print(print_statement) # noqa + + + async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): + self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") + api_key = user_api_key_dict.api_key + max_budget = user_api_key_dict.max_budget + curr_spend = user_api_key_dict.spend + + if api_key is None: + return + + if max_budget is None: + return + + if curr_spend is None: + return + + # CHECK IF REQUEST ALLOWED + if curr_spend >= max_budget: + raise HTTPException(status_code=429, detail="Max budget limit reached.") \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8a32f1b4f..e6ec64823 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -616,7 +616,16 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): router = litellm.Router(**router_params) # type:ignore return router, model_list, general_settings -async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None, metadata: Optional[dict] = {}): +async def generate_key_helper_fn(duration: Optional[str], + models: list, + aliases: dict, + config: dict, + spend: float, + max_budget: Optional[float]=None, + token: Optional[str]=None, + user_id: Optional[str]=None, + max_parallel_requests: Optional[int]=None, + metadata: Optional[dict] = {},): global prisma_client if prisma_client is None: @@ -666,7 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: "spend": spend, "user_id": user_id, "max_parallel_requests": max_parallel_requests, - "metadata": metadata_json + "metadata": metadata_json, + "max_budget": max_budget } new_verification_token = await prisma_client.insert_data(data=verification_token_data) except Exception as e: diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 6cfcdb866..e4acd13e5 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -18,4 +18,5 @@ model LiteLLM_VerificationToken { user_id String? max_parallel_requests Int? metadata Json @default("{}") + max_budget Float? 
} \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 52a2fb6aa..3592593d5 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -4,6 +4,7 @@ import litellm, backoff from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler +from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.integrations.custom_logger import CustomLogger def print_verbose(print_statement): if litellm.set_verbose: @@ -23,11 +24,13 @@ class ProxyLogging: self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache self.max_parallel_request_limiter = MaxParallelRequestsHandler() + self.max_budget_limiter = MaxBudgetLimiter() pass def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") litellm.callbacks.append(self.max_parallel_request_limiter) + litellm.callbacks.append(self.max_budget_limiter) for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) @@ -203,7 +206,6 @@ class PrismaClient: hashed_token = self.hash_token(token=token) db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token - new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore where={ 'token': hashed_token, From d87e59db25921d30f97c19e179bbebebcaeb1da7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 21 Dec 2023 21:20:23 +0530 Subject: [PATCH 53/53] =?UTF-8?q?bump:=20version=201.15.3=20=E2=86=92=201.?= =?UTF-8?q?15.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 62738f3a5..f6f0054f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.15.3" +version = "1.15.4" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License" @@ -55,7 +55,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.15.3" +version = "1.15.4" version_files = [ "pyproject.toml:^version" ]
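To make the intent of the `jsonify_object` change in [PATCH 47/53] concrete, here is a minimal, self-contained sketch of the helper's behavior: nested dict fields such as `metadata` are dumped to JSON strings before being handed to Prisma, which is what the db-write fix relies on. The sample token and metadata values are placeholders, not values taken from the patches.

```python
import copy
import json

def jsonify_object(data: dict) -> dict:
    # Mirror of the helper added in litellm/proxy/utils.py: any nested dict
    # (e.g. the "metadata" column) is serialized to a JSON string so the
    # Prisma Json column accepts it on insert/update.
    db_data = copy.deepcopy(data)
    for k, v in db_data.items():
        if isinstance(v, dict):
            db_data[k] = json.dumps(v)
    return db_data

row = jsonify_object({"token": "hashed-token", "spend": 0.0, "metadata": {"user_api_key": "abc"}})
print(row)  # {'token': 'hashed-token', 'spend': 0.0, 'metadata': '{"user_api_key": "abc"}'}
```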
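For the `/v1/images/generations` route added in [PATCH 51/53], a minimal client-side sketch of how a request might look. It assumes a proxy running locally on port 8000 with a `dall-e-3` entry in `model_list`; the base URL and API key are placeholders rather than values from the patch series, and the request/response shape follows the `test_img_gen` proxy test.

```python
import requests

# Placeholder proxy address and key - adjust to your deployment.
PROXY_BASE = "http://0.0.0.0:8000"
HEADERS = {"Authorization": "Bearer sk-1234"}

resp = requests.post(
    f"{PROXY_BASE}/v1/images/generations",
    headers=HEADERS,
    json={
        "model": "dall-e-3",               # routed through llm_router when the name is in model_list
        "prompt": "A cute baby sea otter",
        "n": 1,
        "size": "1024x1024",
    },
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["data"][0]["url"])       # same field the proxy test asserts on
```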
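And a sketch of the max-budget flow from [PATCH 52/53]: a key is generated with a `max_budget`, and once the tracked spend for that key reaches the budget, the new `MaxBudgetLimiter` pre-call hook rejects further requests with a 429. The host, master key, and the `duration` field here are assumptions for illustration; only `max_budget` and the 429 behavior are taken directly from the patch.

```python
import requests

PROXY_BASE = "http://0.0.0.0:8000"
MASTER = {"Authorization": "Bearer sk-master-1234"}  # placeholder proxy master key

# Generate a key capped at $0.01 of spend (duration shown for illustration).
key_resp = requests.post(
    f"{PROXY_BASE}/key/generate",
    headers=MASTER,
    json={"duration": "1h", "max_budget": 0.01},
    timeout=60,
)
api_key = key_resp.json()["key"]
print(api_key)

# Requests made with api_key succeed until spend >= max_budget, at which point
# MaxBudgetLimiter.async_pre_call_hook raises HTTPException(429, "Max budget limit reached.")
```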