From 460b48914e66f546b406b414c81fb24ded3e719d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 12 Feb 2024 17:25:57 -0800
Subject: [PATCH] feat(sagemaker.py): initial commit of working sagemaker with aioboto3

---
 litellm/llms/sagemaker.py      | 141 +++++++++++++++++++++++----------
 litellm/main.py                |   2 +
 litellm/tests/test_async_fn.py |  16 ++++
 3 files changed, 117 insertions(+), 42 deletions(-)

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 78aafe7f7c..cd107df211 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -1,4 +1,4 @@
-import os, types
+import os, types, traceback
 from enum import Enum
 import json
 import requests
@@ -127,6 +127,7 @@ def completion(
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
+    acompletion: bool = False,
 ):
     import boto3
@@ -213,11 +214,19 @@ def completion(
         )
         return response["Body"]
-
+    elif acompletion == True:
+        _data = {"inputs": prompt, "parameters": inference_params}
+        return async_completion(
+            optional_params=optional_params,
+            encoding=encoding,
+            model_response=model_response,
+            model=model,
+            logging_obj=logging_obj,
+            data=_data,
+        )
     data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
         "utf-8"
     )
-
     ## LOGGING
     request_str = f"""
     response = client.invoke_endpoint(
@@ -302,45 +311,93 @@ def completion(
     return model_response


-# async def acompletion(
-#     client: Any,
-#     model_response: ModelResponse,
-#     model: str,
-#     logging_obj: Any,
-#     data: dict,
-#     hf_model_name: str,
-# ):
-#     """
-#     Use boto3 create_invocation_async endpoint
-#     """
-#     ## LOGGING
-#     request_str = f"""
-#     response = client.invoke_endpoint(
-#         EndpointName={model},
-#         ContentType="application/json",
-#         Body={data},
-#         CustomAttributes="accept_eula=true",
-#     )
-#     """  # type: ignore
-#     logging_obj.pre_call(
-#         input=data["prompt"],
-#         api_key="",
-#         additional_args={
-#             "complete_input_dict": data,
-#             "request_str": request_str,
-#             "hf_model_name": hf_model_name,
-#         },
-#     )
-#     ## COMPLETION CALL
-#     try:
-#         response = client.invoke_endpoint(
-#             EndpointName=model,
-#             ContentType="application/json",
-#             Body=data,
-#             CustomAttributes="accept_eula=true",
-#         )
-#     except Exception as e:
-#         raise SagemakerError(status_code=500, message=f"{str(e)}")
+async def async_completion(
+    optional_params,
+    encoding,
+    model_response: ModelResponse,
+    model: str,
+    logging_obj: Any,
+    data: dict,
+):
+    """
+    Use aioboto3
+    """
+    import aioboto3
+
+    session = aioboto3.Session()
+    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
+        ## LOGGING
+        request_str = f"""
+        response = client.invoke_endpoint(
+            EndpointName={model},
+            ContentType="application/json",
+            Body={data},
+            CustomAttributes="accept_eula=true",
+        )
+        """  # type: ignore
+        logging_obj.pre_call(
+            input=data["inputs"],
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "request_str": request_str,
+            },
+        )
+        encoded_data = json.dumps(data).encode("utf-8")
+        try:
+            response = await client.invoke_endpoint(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=encoded_data,
+                CustomAttributes="accept_eula=true",
+            )
+        except Exception as e:
+            raise SagemakerError(status_code=500, message=f"{str(e)}")
+        response = await response["Body"].read()
+        response = response.decode("utf8")
+        ## LOGGING
+        logging_obj.post_call(
+            input=data["inputs"],
+            api_key="",
+            original_response=response,
+            additional_args={"complete_input_dict": data},
+        )
+        ## RESPONSE OBJECT
+        completion_response = json.loads(response)
+        try:
+            completion_response_choices = completion_response[0]
+            completion_output = ""
+            if "generation" in completion_response_choices:
+                completion_output += completion_response_choices["generation"]
+            elif "generated_text" in completion_response_choices:
+                completion_output += completion_response_choices["generated_text"]
+
+            # check if the prompt template is part of output, if so - filter it out
+            if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
+                completion_output = completion_output.replace(data["inputs"], "", 1)
+
+            model_response["choices"][0]["message"]["content"] = completion_output
+        except:
+            raise SagemakerError(
+                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+                status_code=500,
+            )
+
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(data["inputs"]))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        model_response.usage = usage
+        return model_response


 def embedding(
diff --git a/litellm/main.py b/litellm/main.py
index a7990ecfb4..31dd4324c5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -264,6 +264,7 @@ async def acompletion(
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "sagemaker"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1553,6 +1554,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                acompletion=acompletion,
             )
             if (
                 "stream" in optional_params and optional_params["stream"] == True
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index fa32ce331c..816ef15093 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
 # test_get_cloudflare_response_streaming()


+@pytest.mark.asyncio
+async def test_completion_sagemaker():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_get_response_streaming():
     import asyncio
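
For reference, a minimal usage sketch (not part of the patch) of the async path this change enables, mirroring the new test above. The endpoint name is a placeholder for your own SageMaker endpoint, and AWS credentials are assumed to be available in the environment; note that the async_completion() added here hardcodes region us-west-2.

import asyncio

import litellm


async def main():
    # Routed through the new aioboto3-based async_completion() in sagemaker.py.
    # "my-sagemaker-endpoint" is a placeholder; boto3/aioboto3 read AWS
    # credentials from the environment.
    response = await litellm.acompletion(
        model="sagemaker/my-sagemaker-endpoint",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())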