diff --git a/.circleci/config.yml b/.circleci/config.yml
index c1224159a..a24ae1d8e 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -31,6 +31,7 @@ jobs:
             pip install "google-generativeai>=0.3.2"
             pip install "google-cloud-aiplatform>=1.38.0"
             pip install "boto3>=1.28.57"
+            pip install "aioboto3>=12.3.0"
             pip install langchain
             pip install "langfuse>=2.0.0"
             pip install numpydoc
@@ -124,6 +125,7 @@ jobs:
             pip install "google-generativeai>=0.3.2"
             pip install "google-cloud-aiplatform>=1.38.0"
             pip install "boto3>=1.28.57"
+            pip install "aioboto3>=12.3.0"
             pip install langchain
             pip install "langfuse>=2.0.0"
             pip install numpydoc
diff --git a/docs/my-website/docs/proxy/pii_masking.md b/docs/my-website/docs/proxy/pii_masking.md
index 54fb32bce..7b8400787 100644
--- a/docs/my-website/docs/proxy/pii_masking.md
+++ b/docs/my-website/docs/proxy/pii_masking.md
@@ -4,6 +4,7 @@ import Image from '@theme/IdealImage';
 
 LiteLLM supports [Microsoft Presidio](https://github.com/microsoft/presidio/) for PII masking.
 
+
 ## Quick Start
 ### Step 1. Add env
 
@@ -21,6 +22,7 @@ litellm_settings:
 
 ### Step 3. Start proxy
 
+
 ```
 litellm --config /path/to/config.yaml
 ```
@@ -52,4 +54,4 @@ litellm_settings:
 
 3. LLM Response: "Hey [PERSON], nice to meet you!"
 
-4. User Response: "Hey Jane Doe, nice to meet you!"
\ No newline at end of file
+4. User Response: "Hey Jane Doe, nice to meet you!"
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 96f84c889..88e486f89 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -1,4 +1,4 @@
-import os, types
+import os, types, traceback
 from enum import Enum
 import json
 import requests
@@ -30,8 +30,11 @@ import json
 
 
 class TokenIterator:
-    def __init__(self, stream):
-        self.byte_iterator = iter(stream)
+    def __init__(self, stream, acompletion: bool = False):
+        if acompletion == False:
+            self.byte_iterator = iter(stream)
+        elif acompletion == True:
+            self.byte_iterator = stream
         self.buffer = io.BytesIO()
         self.read_pos = 0
         self.end_of_data = False
@@ -64,6 +67,34 @@ class TokenIterator:
             self.end_of_data = True
             return "data: [DONE]"
 
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            while True:
+                self.buffer.seek(self.read_pos)
+                line = self.buffer.readline()
+                if line and line[-1] == ord("\n"):
+                    response_obj = {"text": "", "is_finished": False}
+                    self.read_pos += len(line) + 1
+                    full_line = line[:-1].decode("utf-8")
+                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                    if line_data.get("generated_text", None) is not None:
+                        self.end_of_data = True
+                        response_obj["is_finished"] = True
+                    response_obj["text"] = line_data["token"]["text"]
+                    return response_obj
+                chunk = await self.byte_iterator.__anext__()
+                self.buffer.seek(0, io.SEEK_END)
+                self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        except StopAsyncIteration as e:
+            if self.end_of_data == True:
+                raise e  # Re-raise StopIteration
+            else:
+                self.end_of_data = True
+                return "data: [DONE]"
+
 
 class SagemakerConfig:
     """
@@ -127,6 +158,7 @@ def completion(
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
+    acompletion: bool = False,
 ):
     import boto3
 
@@ -196,15 +228,16 @@ def completion(
         data = json.dumps(
             {"inputs": prompt, "parameters": inference_params, "stream": True}
         ).encode("utf-8")
-        ## LOGGING
-        request_str = f"""
-        response = client.invoke_endpoint_with_response_stream(
-            EndpointName={model},
-            ContentType="application/json",
-            Body={data},
-            CustomAttributes="accept_eula=true",
-        )
-        """  # type: ignore
+        if acompletion == True:
+            response = async_streaming(
+                optional_params=optional_params,
+                encoding=encoding,
+                model_response=model_response,
+                model=model,
+                logging_obj=logging_obj,
+                data=data,
+            )
+            return response
         response = client.invoke_endpoint_with_response_stream(
             EndpointName=model,
             ContentType="application/json",
             Body=data,
             CustomAttributes="accept_eula=true",
         )
         return response["Body"]
-
+    elif acompletion == True:
+        _data = {"inputs": prompt, "parameters": inference_params}
+        return async_completion(
+            optional_params=optional_params,
+            encoding=encoding,
+            model_response=model_response,
+            model=model,
+            logging_obj=logging_obj,
+            data=_data,
+        )
     data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
         "utf-8"
     )
-
     ## LOGGING
     request_str = f"""
     response = client.invoke_endpoint(
@@ -302,45 +343,122 @@ def completion(
     return model_response
 
 
-# async def acompletion(
-#     client: Any,
-#     model_response: ModelResponse,
-#     model: str,
-#     logging_obj: Any,
-#     data: dict,
-#     hf_model_name: str,
-# ):
-#     """
-#     Use boto3 create_invocation_async endpoint
-#     """
-#     ## LOGGING
-#     request_str = f"""
-#     response = client.invoke_endpoint(
-#         EndpointName={model},
-#         ContentType="application/json",
-#         Body={data},
-#         CustomAttributes="accept_eula=true",
-#     )
-#     """  # type: ignore
-#     logging_obj.pre_call(
-#         input=data["prompt"],
-#         api_key="",
-#         additional_args={
-#             "complete_input_dict": data,
-#             "request_str": request_str,
-#             "hf_model_name": hf_model_name,
-#         },
-#     )
-#     ## COMPLETION CALL
-#     try:
-#         response = client.invoke_endpoint(
-#             EndpointName=model,
-#             ContentType="application/json",
-#             Body=data,
-#             CustomAttributes="accept_eula=true",
-#         )
-#     except Exception as e:
-#         raise SagemakerError(status_code=500, message=f"{str(e)}")
+async def async_streaming(
+    optional_params,
+    encoding,
+    model_response: ModelResponse,
+    model: str,
+    logging_obj: Any,
+    data,
+):
+    """
+    Use aioboto3
+    """
+    import aioboto3
+
+    session = aioboto3.Session()
+    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
+        try:
+            response = await client.invoke_endpoint_with_response_stream(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
+        except Exception as e:
+            raise SagemakerError(status_code=500, message=f"{str(e)}")
+        response = response["Body"]
+        async for chunk in response:
+            yield chunk
+
+
+async def async_completion(
+    optional_params,
+    encoding,
+    model_response: ModelResponse,
+    model: str,
+    logging_obj: Any,
+    data: dict,
+):
+    """
+    Use aioboto3
+    """
+    import aioboto3
+
+    session = aioboto3.Session()
+    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
+        ## LOGGING
+        request_str = f"""
+        response = client.invoke_endpoint(
+            EndpointName={model},
+            ContentType="application/json",
+            Body={data},
+            CustomAttributes="accept_eula=true",
+        )
+        """  # type: ignore
+        logging_obj.pre_call(
+            input=data["inputs"],
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "request_str": request_str,
+            },
+        )
+        encoded_data = json.dumps(data).encode("utf-8")
+        try:
+            response = await client.invoke_endpoint(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=encoded_data,
+                CustomAttributes="accept_eula=true",
+            )
+        except Exception as e:
+            raise SagemakerError(status_code=500, message=f"{str(e)}")
+        response = await response["Body"].read()
+        response = response.decode("utf8")
+        ## LOGGING
+        logging_obj.post_call(
+            input=data["inputs"],
+            api_key="",
+            original_response=response,
+            additional_args={"complete_input_dict": data},
+        )
+        ## RESPONSE OBJECT
+        completion_response = json.loads(response)
+        try:
+            completion_response_choices = completion_response[0]
+            completion_output = ""
+            if "generation" in completion_response_choices:
+                completion_output += completion_response_choices["generation"]
+            elif "generated_text" in completion_response_choices:
+                completion_output += completion_response_choices["generated_text"]
+
+            # check if the prompt template is part of output, if so - filter it out
+            if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
+                completion_output = completion_output.replace(data["inputs"], "", 1)
+
+            model_response["choices"][0]["message"]["content"] = completion_output
+        except:
+            raise SagemakerError(
+                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+                status_code=500,
+            )
+
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(data["inputs"]))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        model_response.usage = usage
+        return model_response
 
 
 def embedding(
diff --git a/litellm/main.py b/litellm/main.py
index 352ce1882..93ea3c644 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -264,6 +264,7 @@ async def acompletion(
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "sagemaker"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1553,6 +1554,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                acompletion=acompletion,
             )
             if (
                 "stream" in optional_params and optional_params["stream"] == True
@@ -1560,7 +1562,7 @@ def completion(
             print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
             from .llms.sagemaker import TokenIterator
 
-            tokenIterator = TokenIterator(model_response)
+            tokenIterator = TokenIterator(model_response, acompletion=acompletion)
             response = CustomStreamWrapper(
                 completion_stream=tokenIterator,
                 model=model,
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 07e08ac61..2b9c77a9d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1907,24 +1907,7 @@ async def async_data_generator(response, user_api_key_dict):
 
 
 def select_data_generator(response, user_api_key_dict):
-    try:
-        # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
-        if hasattr(
-            response, "custom_llm_provider"
-        ) and response.custom_llm_provider in ["sagemaker"]:
-            return data_generator(
-                response=response,
-            )
-        else:
-            # default to async_data_generator
-            return async_data_generator(
-                response=response, user_api_key_dict=user_api_key_dict
-            )
-    except:
-        # worst case - use async_data_generator
-        return async_data_generator(
-            response=response, user_api_key_dict=user_api_key_dict
-        )
+    return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)
 
 
 def get_litellm_model_info(model: dict = {}):
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index fa32ce331..816ef1509 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
 # test_get_cloudflare_response_streaming()
 
 
+@pytest.mark.asyncio
+async def test_completion_sagemaker():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_get_response_streaming():
     import asyncio
 
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 28a9f9902..a5497b539 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -876,7 +876,6 @@ async def test_sagemaker_streaming_async():
             temperature=0.7,
             stream=True,
         )
-        # Add any assertions here to check the response
         print(response)
         complete_response = ""
 
@@ -900,6 +899,9 @@ async def test_sagemaker_streaming_async():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+asyncio.run(test_sagemaker_streaming_async())
+
+
 def test_completion_sagemaker_stream():
     try:
         response = completion(
diff --git a/litellm/utils.py b/litellm/utils.py
index b15be366d..ed03980ce 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8705,6 +8705,8 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama"
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
+                or self.custom_llm_provider == "sagemaker"
+                or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"
diff --git a/requirements.txt b/requirements.txt
index a285e29ea..01251468a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ backoff==2.2.1 # server dep
 pyyaml>=6.0.1 # server dep
 uvicorn==0.22.0 # server dep
 gunicorn==21.2.0 # server dep
-boto3==1.28.58 # aws bedrock/sagemaker calls
+boto3==1.34.34 # aws bedrock/sagemaker calls
 redis==5.0.0 # caching
 numpy==1.24.3 # semantic caching
 prisma==0.11.0 # for db
@@ -30,4 +30,5 @@ click==8.1.7 # for proxy cli
 jinja2==3.1.3 # for prompt templates
 certifi>=2023.7.22 # [TODO] clean up
 aiohttp==3.9.0 # for network calls
+aioboto3==12.3.0 # for async sagemaker calls
 ####
\ No newline at end of file
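
For reference, a minimal sketch of how the async SageMaker path added in this diff is exercised, mirroring the new tests in `test_async_fn.py` and `test_streaming.py`. The endpoint name is the one used in those tests; the sketch assumes valid AWS credentials, a live SageMaker endpoint, and `aioboto3` installed (note the new helpers currently hard-code `region_name="us-west-2"`):

```python
import asyncio

import litellm


async def main():
    # Non-streaming: routed through sagemaker.async_completion() via aioboto3
    response = await litellm.acompletion(
        model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
        messages=[{"content": "Hello, how are you?", "role": "user"}],
    )
    print(response)

    # Streaming: routed through sagemaker.async_streaming() and the async TokenIterator
    stream = await litellm.acompletion(
        model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
        messages=[{"content": "Hello, how are you?", "role": "user"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```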