Merge pull request #1952 from BerriAI/litellm_aioboto3_sagemaker

Implements aioboto3 for sagemaker
This commit is contained in:
Krish Dholakia 2024-02-14 21:47:22 -08:00 committed by GitHub
commit 122ad77d56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 203 additions and 75 deletions

View file

@ -31,6 +31,7 @@ jobs:
pip install "google-generativeai>=0.3.2" pip install "google-generativeai>=0.3.2"
pip install "google-cloud-aiplatform>=1.38.0" pip install "google-cloud-aiplatform>=1.38.0"
pip install "boto3>=1.28.57" pip install "boto3>=1.28.57"
pip install "aioboto3>=12.3.0"
pip install langchain pip install langchain
pip install "langfuse>=2.0.0" pip install "langfuse>=2.0.0"
pip install numpydoc pip install numpydoc
@ -124,6 +125,7 @@ jobs:
pip install "google-generativeai>=0.3.2" pip install "google-generativeai>=0.3.2"
pip install "google-cloud-aiplatform>=1.38.0" pip install "google-cloud-aiplatform>=1.38.0"
pip install "boto3>=1.28.57" pip install "boto3>=1.28.57"
pip install "aioboto3>=12.3.0"
pip install langchain pip install langchain
pip install "langfuse>=2.0.0" pip install "langfuse>=2.0.0"
pip install numpydoc pip install numpydoc

View file

@ -4,6 +4,7 @@ import Image from '@theme/IdealImage';
LiteLLM supports [Microsoft Presidio](https://github.com/microsoft/presidio/) for PII masking. LiteLLM supports [Microsoft Presidio](https://github.com/microsoft/presidio/) for PII masking.
## Quick Start ## Quick Start
### Step 1. Add env ### Step 1. Add env
@ -21,6 +22,7 @@ litellm_settings:
### Step 3. Start proxy ### Step 3. Start proxy
``` ```
litellm --config /path/to/config.yaml litellm --config /path/to/config.yaml
``` ```
@ -52,4 +54,4 @@ litellm_settings:
3. LLM Response: "Hey [PERSON], nice to meet you!" 3. LLM Response: "Hey [PERSON], nice to meet you!"
4. User Response: "Hey Jane Doe, nice to meet you!" 4. User Response: "Hey Jane Doe, nice to meet you!"

View file

@ -1,4 +1,4 @@
import os, types import os, types, traceback
from enum import Enum from enum import Enum
import json import json
import requests import requests
@ -30,8 +30,11 @@ import json
class TokenIterator: class TokenIterator:
def __init__(self, stream): def __init__(self, stream, acompletion: bool = False):
self.byte_iterator = iter(stream) if acompletion == False:
self.byte_iterator = iter(stream)
elif acompletion == True:
self.byte_iterator = stream
self.buffer = io.BytesIO() self.buffer = io.BytesIO()
self.read_pos = 0 self.read_pos = 0
self.end_of_data = False self.end_of_data = False
@ -64,6 +67,34 @@ class TokenIterator:
self.end_of_data = True self.end_of_data = True
return "data: [DONE]" return "data: [DONE]"
def __aiter__(self):
return self
async def __anext__(self):
try:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
response_obj = {"text": "", "is_finished": False}
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
if line_data.get("generated_text", None) is not None:
self.end_of_data = True
response_obj["is_finished"] = True
response_obj["text"] = line_data["token"]["text"]
return response_obj
chunk = await self.byte_iterator.__anext__()
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
except StopAsyncIteration as e:
if self.end_of_data == True:
raise e # Re-raise StopIteration
else:
self.end_of_data = True
return "data: [DONE]"
class SagemakerConfig: class SagemakerConfig:
""" """
@ -127,6 +158,7 @@ def completion(
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool = False,
): ):
import boto3 import boto3
@ -196,15 +228,16 @@ def completion(
data = json.dumps( data = json.dumps(
{"inputs": prompt, "parameters": inference_params, "stream": True} {"inputs": prompt, "parameters": inference_params, "stream": True}
).encode("utf-8") ).encode("utf-8")
## LOGGING if acompletion == True:
request_str = f""" response = async_streaming(
response = client.invoke_endpoint_with_response_stream( optional_params=optional_params,
EndpointName={model}, encoding=encoding,
ContentType="application/json", model_response=model_response,
Body={data}, model=model,
CustomAttributes="accept_eula=true", logging_obj=logging_obj,
) data=data,
""" # type: ignore )
return response
response = client.invoke_endpoint_with_response_stream( response = client.invoke_endpoint_with_response_stream(
EndpointName=model, EndpointName=model,
ContentType="application/json", ContentType="application/json",
@ -213,11 +246,19 @@ def completion(
) )
return response["Body"] return response["Body"]
elif acompletion == True:
_data = {"inputs": prompt, "parameters": inference_params}
return async_completion(
optional_params=optional_params,
encoding=encoding,
model_response=model_response,
model=model,
logging_obj=logging_obj,
data=_data,
)
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8" "utf-8"
) )
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_endpoint( response = client.invoke_endpoint(
@ -302,45 +343,122 @@ def completion(
return model_response return model_response
# async def acompletion( async def async_streaming(
# client: Any, optional_params,
# model_response: ModelResponse, encoding,
# model: str, model_response: ModelResponse,
# logging_obj: Any, model: str,
# data: dict, logging_obj: Any,
# hf_model_name: str, data,
# ): ):
# """ """
# Use boto3 create_invocation_async endpoint Use aioboto3
# """ """
# ## LOGGING import aioboto3
# request_str = f"""
# response = client.invoke_endpoint( session = aioboto3.Session()
# EndpointName={model}, async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
# ContentType="application/json", try:
# Body={data}, response = await client.invoke_endpoint_with_response_stream(
# CustomAttributes="accept_eula=true", EndpointName=model,
# ) ContentType="application/json",
# """ # type: ignore Body=data,
# logging_obj.pre_call( CustomAttributes="accept_eula=true",
# input=data["prompt"], )
# api_key="", except Exception as e:
# additional_args={ raise SagemakerError(status_code=500, message=f"{str(e)}")
# "complete_input_dict": data, response = response["Body"]
# "request_str": request_str, async for chunk in response:
# "hf_model_name": hf_model_name, yield chunk
# },
# )
# ## COMPLETION CALL async def async_completion(
# try: optional_params,
# response = client.invoke_endpoint( encoding,
# EndpointName=model, model_response: ModelResponse,
# ContentType="application/json", model: str,
# Body=data, logging_obj: Any,
# CustomAttributes="accept_eula=true", data: dict,
# ) ):
# except Exception as e: """
# raise SagemakerError(status_code=500, message=f"{str(e)}") Use aioboto3
"""
import aioboto3
session = aioboto3.Session()
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
encoded_data = json.dumps(data).encode("utf-8")
try:
response = await client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = await response["Body"].read()
response = response.decode("utf8")
## LOGGING
logging_obj.post_call(
input=data["inputs"],
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
completion_response_choices = completion_response[0]
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
completion_output = completion_output.replace(data["inputs"], "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
except:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(data["inputs"]))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding( def embedding(

View file

@ -264,6 +264,7 @@ async def acompletion(
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "ollama_chat" or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "sagemaker"
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -1553,6 +1554,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion,
) )
if ( if (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] == True
@ -1560,7 +1562,7 @@ def completion(
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator from .llms.sagemaker import TokenIterator
tokenIterator = TokenIterator(model_response) tokenIterator = TokenIterator(model_response, acompletion=acompletion)
response = CustomStreamWrapper( response = CustomStreamWrapper(
completion_stream=tokenIterator, completion_stream=tokenIterator,
model=model, model=model,

View file

@ -1907,24 +1907,7 @@ async def async_data_generator(response, user_api_key_dict):
def select_data_generator(response, user_api_key_dict): def select_data_generator(response, user_api_key_dict):
try: return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)
# since boto3 - sagemaker does not support async calls, we should use a sync data_generator
if hasattr(
response, "custom_llm_provider"
) and response.custom_llm_provider in ["sagemaker"]:
return data_generator(
response=response,
)
else:
# default to async_data_generator
return async_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
except:
# worst case - use async_data_generator
return async_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
def get_litellm_model_info(model: dict = {}): def get_litellm_model_info(model: dict = {}):

View file

@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
# test_get_cloudflare_response_streaming() # test_get_cloudflare_response_streaming()
@pytest.mark.asyncio
async def test_completion_sagemaker():
# litellm.set_verbose=True
try:
response = await acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
print(response)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_get_response_streaming(): def test_get_response_streaming():
import asyncio import asyncio

View file

@ -876,7 +876,6 @@ async def test_sagemaker_streaming_async():
temperature=0.7, temperature=0.7,
stream=True, stream=True,
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
complete_response = "" complete_response = ""
@ -900,6 +899,9 @@ async def test_sagemaker_streaming_async():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
asyncio.run(test_sagemaker_streaming_async())
def test_completion_sagemaker_stream(): def test_completion_sagemaker_stream():
try: try:
response = completion( response = completion(

View file

@ -8705,6 +8705,8 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "ollama" or self.custom_llm_provider == "ollama"
or self.custom_llm_provider == "ollama_chat" or self.custom_llm_provider == "ollama_chat"
or self.custom_llm_provider == "vertex_ai" or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
): ):
print_verbose( print_verbose(
f"value of async completion stream: {self.completion_stream}" f"value of async completion stream: {self.completion_stream}"

View file

@ -7,7 +7,7 @@ backoff==2.2.1 # server dep
pyyaml>=6.0.1 # server dep pyyaml>=6.0.1 # server dep
uvicorn==0.22.0 # server dep uvicorn==0.22.0 # server dep
gunicorn==21.2.0 # server dep gunicorn==21.2.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls boto3==1.34.34 # aws bedrock/sagemaker calls
redis==5.0.0 # caching redis==5.0.0 # caching
numpy==1.24.3 # semantic caching numpy==1.24.3 # semantic caching
prisma==0.11.0 # for db prisma==0.11.0 # for db
@ -30,4 +30,5 @@ click==8.1.7 # for proxy cli
jinja2==3.1.3 # for prompt templates jinja2==3.1.3 # for prompt templates
certifi>=2023.7.22 # [TODO] clean up certifi>=2023.7.22 # [TODO] clean up
aiohttp==3.9.0 # for network calls aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls
#### ####