feat(sagemaker.py): initial commit of working sagemaker with aioboto3

This commit is contained in:
Krrish Dholakia 2024-02-12 17:25:57 -08:00
parent ad7e856a02
commit 460b48914e
3 changed files with 117 additions and 42 deletions

View file

@ -1,4 +1,4 @@
import os, types import os, types, traceback
from enum import Enum from enum import Enum
import json import json
import requests import requests
@ -127,6 +127,7 @@ def completion(
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool = False,
): ):
import boto3 import boto3
@ -213,11 +214,19 @@ def completion(
) )
return response["Body"] return response["Body"]
elif acompletion == True:
_data = {"inputs": prompt, "parameters": inference_params}
return async_completion(
optional_params=optional_params,
encoding=encoding,
model_response=model_response,
model=model,
logging_obj=logging_obj,
data=_data,
)
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8" "utf-8"
) )
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_endpoint( response = client.invoke_endpoint(
@ -302,45 +311,93 @@ def completion(
return model_response return model_response
# async def acompletion( async def async_completion(
# client: Any, optional_params,
# model_response: ModelResponse, encoding,
# model: str, model_response: ModelResponse,
# logging_obj: Any, model: str,
# data: dict, logging_obj: Any,
# hf_model_name: str, data: dict,
# ): ):
# """ """
# Use boto3 create_invocation_async endpoint Use aioboto3
# """ """
# ## LOGGING import aioboto3
# request_str = f"""
# response = client.invoke_endpoint( session = aioboto3.Session()
# EndpointName={model}, async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
# ContentType="application/json", ## LOGGING
# Body={data}, request_str = f"""
# CustomAttributes="accept_eula=true", response = client.invoke_endpoint(
# ) EndpointName={model},
# """ # type: ignore ContentType="application/json",
# logging_obj.pre_call( Body={data},
# input=data["prompt"], CustomAttributes="accept_eula=true",
# api_key="", )
# additional_args={ """ # type: ignore
# "complete_input_dict": data, logging_obj.pre_call(
# "request_str": request_str, input=data["inputs"],
# "hf_model_name": hf_model_name, api_key="",
# }, additional_args={
# ) "complete_input_dict": data,
# ## COMPLETION CALL "request_str": request_str,
# try: },
# response = client.invoke_endpoint( )
# EndpointName=model, encoded_data = json.dumps(data).encode("utf-8")
# ContentType="application/json", try:
# Body=data, response = await client.invoke_endpoint(
# CustomAttributes="accept_eula=true", EndpointName=model,
# ) ContentType="application/json",
# except Exception as e: Body=encoded_data,
# raise SagemakerError(status_code=500, message=f"{str(e)}") CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = await response["Body"].read()
response = response.decode("utf8")
## LOGGING
logging_obj.post_call(
input=data["inputs"],
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
completion_response_choices = completion_response[0]
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
completion_output = completion_output.replace(data["inputs"], "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
except:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(data["inputs"]))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding( def embedding(

View file

@ -264,6 +264,7 @@ async def acompletion(
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "ollama_chat" or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "sagemaker"
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -1553,6 +1554,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion,
) )
if ( if (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] == True

View file

@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
# test_get_cloudflare_response_streaming() # test_get_cloudflare_response_streaming()
@pytest.mark.asyncio
async def test_completion_sagemaker():
# litellm.set_verbose=True
try:
response = await acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
print(response)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_get_response_streaming(): def test_get_response_streaming():
import asyncio import asyncio