forked from phoenix/litellm-mirror
Merge pull request #1952 from BerriAI/litellm_aioboto3_sagemaker
Implements aioboto3 for sagemaker
This commit is contained in:
commit
122ad77d56
9 changed files with 203 additions and 75 deletions
|
@ -31,6 +31,7 @@ jobs:
|
||||||
pip install "google-generativeai>=0.3.2"
|
pip install "google-generativeai>=0.3.2"
|
||||||
pip install "google-cloud-aiplatform>=1.38.0"
|
pip install "google-cloud-aiplatform>=1.38.0"
|
||||||
pip install "boto3>=1.28.57"
|
pip install "boto3>=1.28.57"
|
||||||
|
pip install "aioboto3>=12.3.0"
|
||||||
pip install langchain
|
pip install langchain
|
||||||
pip install "langfuse>=2.0.0"
|
pip install "langfuse>=2.0.0"
|
||||||
pip install numpydoc
|
pip install numpydoc
|
||||||
|
@ -124,6 +125,7 @@ jobs:
|
||||||
pip install "google-generativeai>=0.3.2"
|
pip install "google-generativeai>=0.3.2"
|
||||||
pip install "google-cloud-aiplatform>=1.38.0"
|
pip install "google-cloud-aiplatform>=1.38.0"
|
||||||
pip install "boto3>=1.28.57"
|
pip install "boto3>=1.28.57"
|
||||||
|
pip install "aioboto3>=12.3.0"
|
||||||
pip install langchain
|
pip install langchain
|
||||||
pip install "langfuse>=2.0.0"
|
pip install "langfuse>=2.0.0"
|
||||||
pip install numpydoc
|
pip install numpydoc
|
||||||
|
|
|
@ -4,6 +4,7 @@ import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
LiteLLM supports [Microsoft Presidio](https://github.com/microsoft/presidio/) for PII masking.
|
LiteLLM supports [Microsoft Presidio](https://github.com/microsoft/presidio/) for PII masking.
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
### Step 1. Add env
|
### Step 1. Add env
|
||||||
|
|
||||||
|
@ -21,6 +22,7 @@ litellm_settings:
|
||||||
|
|
||||||
### Step 3. Start proxy
|
### Step 3. Start proxy
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
litellm --config /path/to/config.yaml
|
litellm --config /path/to/config.yaml
|
||||||
```
|
```
|
||||||
|
@ -52,4 +54,4 @@ litellm_settings:
|
||||||
|
|
||||||
3. LLM Response: "Hey [PERSON], nice to meet you!"
|
3. LLM Response: "Hey [PERSON], nice to meet you!"
|
||||||
|
|
||||||
4. User Response: "Hey Jane Doe, nice to meet you!"
|
4. User Response: "Hey Jane Doe, nice to meet you!"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import os, types
|
import os, types, traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
@ -30,8 +30,11 @@ import json
|
||||||
|
|
||||||
|
|
||||||
class TokenIterator:
|
class TokenIterator:
|
||||||
def __init__(self, stream):
|
def __init__(self, stream, acompletion: bool = False):
|
||||||
self.byte_iterator = iter(stream)
|
if acompletion == False:
|
||||||
|
self.byte_iterator = iter(stream)
|
||||||
|
elif acompletion == True:
|
||||||
|
self.byte_iterator = stream
|
||||||
self.buffer = io.BytesIO()
|
self.buffer = io.BytesIO()
|
||||||
self.read_pos = 0
|
self.read_pos = 0
|
||||||
self.end_of_data = False
|
self.end_of_data = False
|
||||||
|
@ -64,6 +67,34 @@ class TokenIterator:
|
||||||
self.end_of_data = True
|
self.end_of_data = True
|
||||||
return "data: [DONE]"
|
return "data: [DONE]"
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.buffer.seek(self.read_pos)
|
||||||
|
line = self.buffer.readline()
|
||||||
|
if line and line[-1] == ord("\n"):
|
||||||
|
response_obj = {"text": "", "is_finished": False}
|
||||||
|
self.read_pos += len(line) + 1
|
||||||
|
full_line = line[:-1].decode("utf-8")
|
||||||
|
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
|
||||||
|
if line_data.get("generated_text", None) is not None:
|
||||||
|
self.end_of_data = True
|
||||||
|
response_obj["is_finished"] = True
|
||||||
|
response_obj["text"] = line_data["token"]["text"]
|
||||||
|
return response_obj
|
||||||
|
chunk = await self.byte_iterator.__anext__()
|
||||||
|
self.buffer.seek(0, io.SEEK_END)
|
||||||
|
self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
||||||
|
except StopAsyncIteration as e:
|
||||||
|
if self.end_of_data == True:
|
||||||
|
raise e # Re-raise StopIteration
|
||||||
|
else:
|
||||||
|
self.end_of_data = True
|
||||||
|
return "data: [DONE]"
|
||||||
|
|
||||||
|
|
||||||
class SagemakerConfig:
|
class SagemakerConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -127,6 +158,7 @@ def completion(
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
acompletion: bool = False,
|
||||||
):
|
):
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
@ -196,15 +228,16 @@ def completion(
|
||||||
data = json.dumps(
|
data = json.dumps(
|
||||||
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
## LOGGING
|
if acompletion == True:
|
||||||
request_str = f"""
|
response = async_streaming(
|
||||||
response = client.invoke_endpoint_with_response_stream(
|
optional_params=optional_params,
|
||||||
EndpointName={model},
|
encoding=encoding,
|
||||||
ContentType="application/json",
|
model_response=model_response,
|
||||||
Body={data},
|
model=model,
|
||||||
CustomAttributes="accept_eula=true",
|
logging_obj=logging_obj,
|
||||||
)
|
data=data,
|
||||||
""" # type: ignore
|
)
|
||||||
|
return response
|
||||||
response = client.invoke_endpoint_with_response_stream(
|
response = client.invoke_endpoint_with_response_stream(
|
||||||
EndpointName=model,
|
EndpointName=model,
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
|
@ -213,11 +246,19 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
return response["Body"]
|
return response["Body"]
|
||||||
|
elif acompletion == True:
|
||||||
|
_data = {"inputs": prompt, "parameters": inference_params}
|
||||||
|
return async_completion(
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
model_response=model_response,
|
||||||
|
model=model,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
data=_data,
|
||||||
|
)
|
||||||
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
||||||
"utf-8"
|
"utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
request_str = f"""
|
request_str = f"""
|
||||||
response = client.invoke_endpoint(
|
response = client.invoke_endpoint(
|
||||||
|
@ -302,45 +343,122 @@ def completion(
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
# async def acompletion(
|
async def async_streaming(
|
||||||
# client: Any,
|
optional_params,
|
||||||
# model_response: ModelResponse,
|
encoding,
|
||||||
# model: str,
|
model_response: ModelResponse,
|
||||||
# logging_obj: Any,
|
model: str,
|
||||||
# data: dict,
|
logging_obj: Any,
|
||||||
# hf_model_name: str,
|
data,
|
||||||
# ):
|
):
|
||||||
# """
|
"""
|
||||||
# Use boto3 create_invocation_async endpoint
|
Use aioboto3
|
||||||
# """
|
"""
|
||||||
# ## LOGGING
|
import aioboto3
|
||||||
# request_str = f"""
|
|
||||||
# response = client.invoke_endpoint(
|
session = aioboto3.Session()
|
||||||
# EndpointName={model},
|
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
|
||||||
# ContentType="application/json",
|
try:
|
||||||
# Body={data},
|
response = await client.invoke_endpoint_with_response_stream(
|
||||||
# CustomAttributes="accept_eula=true",
|
EndpointName=model,
|
||||||
# )
|
ContentType="application/json",
|
||||||
# """ # type: ignore
|
Body=data,
|
||||||
# logging_obj.pre_call(
|
CustomAttributes="accept_eula=true",
|
||||||
# input=data["prompt"],
|
)
|
||||||
# api_key="",
|
except Exception as e:
|
||||||
# additional_args={
|
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||||
# "complete_input_dict": data,
|
response = response["Body"]
|
||||||
# "request_str": request_str,
|
async for chunk in response:
|
||||||
# "hf_model_name": hf_model_name,
|
yield chunk
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# ## COMPLETION CALL
|
async def async_completion(
|
||||||
# try:
|
optional_params,
|
||||||
# response = client.invoke_endpoint(
|
encoding,
|
||||||
# EndpointName=model,
|
model_response: ModelResponse,
|
||||||
# ContentType="application/json",
|
model: str,
|
||||||
# Body=data,
|
logging_obj: Any,
|
||||||
# CustomAttributes="accept_eula=true",
|
data: dict,
|
||||||
# )
|
):
|
||||||
# except Exception as e:
|
"""
|
||||||
# raise SagemakerError(status_code=500, message=f"{str(e)}")
|
Use aioboto3
|
||||||
|
"""
|
||||||
|
import aioboto3
|
||||||
|
|
||||||
|
session = aioboto3.Session()
|
||||||
|
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
|
||||||
|
## LOGGING
|
||||||
|
request_str = f"""
|
||||||
|
response = client.invoke_endpoint(
|
||||||
|
EndpointName={model},
|
||||||
|
ContentType="application/json",
|
||||||
|
Body={data},
|
||||||
|
CustomAttributes="accept_eula=true",
|
||||||
|
)
|
||||||
|
""" # type: ignore
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=data["inputs"],
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"request_str": request_str,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
encoded_data = json.dumps(data).encode("utf-8")
|
||||||
|
try:
|
||||||
|
response = await client.invoke_endpoint(
|
||||||
|
EndpointName=model,
|
||||||
|
ContentType="application/json",
|
||||||
|
Body=encoded_data,
|
||||||
|
CustomAttributes="accept_eula=true",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||||
|
response = await response["Body"].read()
|
||||||
|
response = response.decode("utf8")
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data["inputs"],
|
||||||
|
api_key="",
|
||||||
|
original_response=response,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
completion_response = json.loads(response)
|
||||||
|
try:
|
||||||
|
completion_response_choices = completion_response[0]
|
||||||
|
completion_output = ""
|
||||||
|
if "generation" in completion_response_choices:
|
||||||
|
completion_output += completion_response_choices["generation"]
|
||||||
|
elif "generated_text" in completion_response_choices:
|
||||||
|
completion_output += completion_response_choices["generated_text"]
|
||||||
|
|
||||||
|
# check if the prompt template is part of output, if so - filter it out
|
||||||
|
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
|
||||||
|
completion_output = completion_output.replace(data["inputs"], "", 1)
|
||||||
|
|
||||||
|
model_response["choices"][0]["message"]["content"] = completion_output
|
||||||
|
except:
|
||||||
|
raise SagemakerError(
|
||||||
|
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||||
|
prompt_tokens = len(encoding.encode(data["inputs"]))
|
||||||
|
completion_tokens = len(
|
||||||
|
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
model_response.usage = usage
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
def embedding(
|
def embedding(
|
||||||
|
|
|
@ -264,6 +264,7 @@ async def acompletion(
|
||||||
or custom_llm_provider == "ollama"
|
or custom_llm_provider == "ollama"
|
||||||
or custom_llm_provider == "ollama_chat"
|
or custom_llm_provider == "ollama_chat"
|
||||||
or custom_llm_provider == "vertex_ai"
|
or custom_llm_provider == "vertex_ai"
|
||||||
|
or custom_llm_provider == "sagemaker"
|
||||||
or custom_llm_provider in litellm.openai_compatible_providers
|
or custom_llm_provider in litellm.openai_compatible_providers
|
||||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||||
init_response = await loop.run_in_executor(None, func_with_context)
|
init_response = await loop.run_in_executor(None, func_with_context)
|
||||||
|
@ -1553,6 +1554,7 @@ def completion(
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
|
acompletion=acompletion,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"stream" in optional_params and optional_params["stream"] == True
|
"stream" in optional_params and optional_params["stream"] == True
|
||||||
|
@ -1560,7 +1562,7 @@ def completion(
|
||||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||||
from .llms.sagemaker import TokenIterator
|
from .llms.sagemaker import TokenIterator
|
||||||
|
|
||||||
tokenIterator = TokenIterator(model_response)
|
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
completion_stream=tokenIterator,
|
completion_stream=tokenIterator,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -1907,24 +1907,7 @@ async def async_data_generator(response, user_api_key_dict):
|
||||||
|
|
||||||
|
|
||||||
def select_data_generator(response, user_api_key_dict):
|
def select_data_generator(response, user_api_key_dict):
|
||||||
try:
|
return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)
|
||||||
# since boto3 - sagemaker does not support async calls, we should use a sync data_generator
|
|
||||||
if hasattr(
|
|
||||||
response, "custom_llm_provider"
|
|
||||||
) and response.custom_llm_provider in ["sagemaker"]:
|
|
||||||
return data_generator(
|
|
||||||
response=response,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# default to async_data_generator
|
|
||||||
return async_data_generator(
|
|
||||||
response=response, user_api_key_dict=user_api_key_dict
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
# worst case - use async_data_generator
|
|
||||||
return async_data_generator(
|
|
||||||
response=response, user_api_key_dict=user_api_key_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_litellm_model_info(model: dict = {}):
|
def get_litellm_model_info(model: dict = {}):
|
||||||
|
|
|
@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
|
||||||
# test_get_cloudflare_response_streaming()
|
# test_get_cloudflare_response_streaming()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_sagemaker():
|
||||||
|
# litellm.set_verbose=True
|
||||||
|
try:
|
||||||
|
response = await acompletion(
|
||||||
|
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||||
|
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_get_response_streaming():
|
def test_get_response_streaming():
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
@ -876,7 +876,6 @@ async def test_sagemaker_streaming_async():
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
print(response)
|
print(response)
|
||||||
complete_response = ""
|
complete_response = ""
|
||||||
|
@ -900,6 +899,9 @@ async def test_sagemaker_streaming_async():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(test_sagemaker_streaming_async())
|
||||||
|
|
||||||
|
|
||||||
def test_completion_sagemaker_stream():
|
def test_completion_sagemaker_stream():
|
||||||
try:
|
try:
|
||||||
response = completion(
|
response = completion(
|
||||||
|
|
|
@ -8705,6 +8705,8 @@ class CustomStreamWrapper:
|
||||||
or self.custom_llm_provider == "ollama"
|
or self.custom_llm_provider == "ollama"
|
||||||
or self.custom_llm_provider == "ollama_chat"
|
or self.custom_llm_provider == "ollama_chat"
|
||||||
or self.custom_llm_provider == "vertex_ai"
|
or self.custom_llm_provider == "vertex_ai"
|
||||||
|
or self.custom_llm_provider == "sagemaker"
|
||||||
|
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||||
):
|
):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"value of async completion stream: {self.completion_stream}"
|
f"value of async completion stream: {self.completion_stream}"
|
||||||
|
|
|
@ -7,7 +7,7 @@ backoff==2.2.1 # server dep
|
||||||
pyyaml>=6.0.1 # server dep
|
pyyaml>=6.0.1 # server dep
|
||||||
uvicorn==0.22.0 # server dep
|
uvicorn==0.22.0 # server dep
|
||||||
gunicorn==21.2.0 # server dep
|
gunicorn==21.2.0 # server dep
|
||||||
boto3==1.28.58 # aws bedrock/sagemaker calls
|
boto3==1.34.34 # aws bedrock/sagemaker calls
|
||||||
redis==5.0.0 # caching
|
redis==5.0.0 # caching
|
||||||
numpy==1.24.3 # semantic caching
|
numpy==1.24.3 # semantic caching
|
||||||
prisma==0.11.0 # for db
|
prisma==0.11.0 # for db
|
||||||
|
@ -30,4 +30,5 @@ click==8.1.7 # for proxy cli
|
||||||
jinja2==3.1.3 # for prompt templates
|
jinja2==3.1.3 # for prompt templates
|
||||||
certifi>=2023.7.22 # [TODO] clean up
|
certifi>=2023.7.22 # [TODO] clean up
|
||||||
aiohttp==3.9.0 # for network calls
|
aiohttp==3.9.0 # for network calls
|
||||||
|
aioboto3==12.3.0 # for async sagemaker calls
|
||||||
####
|
####
|
Loading…
Add table
Add a link
Reference in a new issue