mirror of https://github.com/BerriAI/litellm.git

feat(sagemaker.py): initial commit of working sagemaker with aioboto3

parent ad7e856a02
commit 460b48914e

3 changed files with 117 additions and 42 deletions
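
In short: the SageMaker provider's synchronous completion() gains an acompletion flag; when the flag is set, completion() returns a coroutine from a new async_completion() helper that invokes the endpoint through aioboto3 instead of the blocking boto3 client. The commented-out boto3 "create_invocation_async" sketch it replaces is dropped, litellm's top-level acompletion() learns to route "sagemaker" through the async path, and a new async test exercises the flow end to end.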
File 1: sagemaker.py

@@ -1,4 +1,4 @@
-import os, types
+import os, types, traceback
 from enum import Enum
 import json
 import requests
@@ -127,6 +127,7 @@ def completion(
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
+    acompletion: bool = False,
 ):
     import boto3

@@ -213,11 +214,19 @@ def completion(
         )

         return response["Body"]

+    elif acompletion == True:
+        _data = {"inputs": prompt, "parameters": inference_params}
+        return async_completion(
+            optional_params=optional_params,
+            encoding=encoding,
+            model_response=model_response,
+            model=model,
+            logging_obj=logging_obj,
+            data=_data,
+        )
     data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
         "utf-8"
     )

     ## LOGGING
     request_str = f"""
     response = client.invoke_endpoint(
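
The elif branch above is the entire async hook on the sync side: when acompletion is truthy, completion() hands back an un-awaited coroutine and never touches boto3. A minimal, runnable sketch of that hand-off pattern (entry_point and _do_work are illustrative names, not litellm's):

import asyncio

async def _do_work(data: dict) -> str:
    # Stand-in for the real async network call.
    return f"processed: {data['inputs']}"

def entry_point(data: dict, acompletion: bool = False):
    if acompletion:
        # Hand back the un-awaited coroutine; the async caller awaits it.
        return _do_work(data)
    # Blocking path: drive the coroutine to completion ourselves.
    return asyncio.run(_do_work(data))

async def main():
    result = await entry_point({"inputs": "hi"}, acompletion=True)
    print(result)  # -> processed: hi

asyncio.run(main())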
@@ -302,45 +311,93 @@ def completion(
         return model_response


-# async def acompletion(
-#     client: Any,
-#     model_response: ModelResponse,
-#     model: str,
-#     logging_obj: Any,
-#     data: dict,
-#     hf_model_name: str,
-# ):
-#     """
-#     Use boto3 create_invocation_async endpoint
-#     """
-#     ## LOGGING
-#     request_str = f"""
-#     response = client.invoke_endpoint(
-#         EndpointName={model},
-#         ContentType="application/json",
-#         Body={data},
-#         CustomAttributes="accept_eula=true",
-#     )
-#     """  # type: ignore
-#     logging_obj.pre_call(
-#         input=data["prompt"],
-#         api_key="",
-#         additional_args={
-#             "complete_input_dict": data,
-#             "request_str": request_str,
-#             "hf_model_name": hf_model_name,
-#         },
-#     )
-#     ## COMPLETION CALL
-#     try:
-#         response = client.invoke_endpoint(
-#             EndpointName=model,
-#             ContentType="application/json",
-#             Body=data,
-#             CustomAttributes="accept_eula=true",
-#         )
-#     except Exception as e:
-#         raise SagemakerError(status_code=500, message=f"{str(e)}")

+async def async_completion(
+    optional_params,
+    encoding,
+    model_response: ModelResponse,
+    model: str,
+    logging_obj: Any,
+    data: dict,
+):
+    """
+    Use aioboto3
+    """
+    import aioboto3
+
+    session = aioboto3.Session()
+    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
+        ## LOGGING
+        request_str = f"""
+        response = client.invoke_endpoint(
+            EndpointName={model},
+            ContentType="application/json",
+            Body={data},
+            CustomAttributes="accept_eula=true",
+        )
+        """  # type: ignore
+        logging_obj.pre_call(
+            input=data["inputs"],
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "request_str": request_str,
+            },
+        )
+        encoded_data = json.dumps(data).encode("utf-8")
+        try:
+            response = await client.invoke_endpoint(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=encoded_data,
+                CustomAttributes="accept_eula=true",
+            )
+        except Exception as e:
+            raise SagemakerError(status_code=500, message=f"{str(e)}")
+        response = await response["Body"].read()
+        response = response.decode("utf8")
+        ## LOGGING
+        logging_obj.post_call(
+            input=data["inputs"],
+            api_key="",
+            original_response=response,
+            additional_args={"complete_input_dict": data},
+        )
+    ## RESPONSE OBJECT
+    completion_response = json.loads(response)
+    try:
+        completion_response_choices = completion_response[0]
+        completion_output = ""
+        if "generation" in completion_response_choices:
+            completion_output += completion_response_choices["generation"]
+        elif "generated_text" in completion_response_choices:
+            completion_output += completion_response_choices["generated_text"]
+
+        # check if the prompt template is part of output, if so - filter it out
+        if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
+            completion_output = completion_output.replace(data["inputs"], "", 1)
+
+        model_response["choices"][0]["message"]["content"] = completion_output
+    except:
+        raise SagemakerError(
+            message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+            status_code=500,
+        )
+
+    ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+    prompt_tokens = len(encoding.encode(data["inputs"]))
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+    )
+
+    model_response["created"] = int(time.time())
+    model_response["model"] = model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    model_response.usage = usage
+    return model_response


 def embedding(
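
For reference, the core aioboto3 pattern the new helper builds on, stripped of litellm's logging and response shaping. A sketch under stated assumptions: invoke() and "my-endpoint" are hypothetical names, and credentials come from the standard AWS environment; this is not litellm's code.

import asyncio
import json

import aioboto3

async def invoke(prompt: str) -> str:
    session = aioboto3.Session()
    # session.client(...) is an async context manager; the HTTP call and the
    # body read are both awaited, so the event loop is never blocked.
    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
        response = await client.invoke_endpoint(
            EndpointName="my-endpoint",  # hypothetical endpoint name
            ContentType="application/json",
            Body=json.dumps({"inputs": prompt}).encode("utf-8"),
        )
        raw = await response["Body"].read()
    return raw.decode("utf-8")

# asyncio.run(invoke("Hello"))  # needs live AWS credentials and a deployed endpoint

One design note visible in the diff itself: the region is hardcoded to us-west-2, so endpoints deployed elsewhere would not resolve; making it configurable looks like a natural follow-up.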

File 2: litellm's core routing

@@ -264,6 +264,7 @@ async def acompletion(
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "sagemaker"
             or custom_llm_provider in litellm.openai_compatible_providers
         ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
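
The init_response line above is only half the story: run_in_executor runs the synchronous completion() in a worker thread, and for the providers on this list the value that comes back is itself a coroutine that the caller then awaits. A runnable sketch of that two-step await (names are illustrative, not litellm's exact code):

import asyncio

def sync_completion(acompletion: bool = False):
    # Mirrors the provider code path: return a coroutine when async is requested.
    async def _async_part():
        return "async result"
    return _async_part() if acompletion else "sync result"

async def acompletion_entry():
    loop = asyncio.get_running_loop()
    # Step 1: run the sync entry point in a thread so the loop is not blocked.
    init_response = await loop.run_in_executor(
        None, lambda: sync_completion(acompletion=True)
    )
    # Step 2: if it handed back a coroutine, await that too.
    if asyncio.iscoroutine(init_response):
        init_response = await init_response
    return init_response

print(asyncio.run(acompletion_entry()))  # -> async result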

@@ -1553,6 +1554,7 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging,
+            acompletion=acompletion,
         )
         if (
             "stream" in optional_params and optional_params["stream"] == True

File 3: tests

@@ -201,6 +201,22 @@ async def test_hf_completion_tgi():
 # test_get_cloudflare_response_streaming()
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker():
+    # litellm.set_verbose=True
+    try:
+        response = await acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 def test_get_response_streaming():
     import asyncio
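
Note on running this test: it calls a live endpoint, so it needs AWS credentials allowed to invoke SageMaker endpoints and the author's berri-benchmarking-Llama-2-70b-chat-hf-4 deployment (or a substitute endpoint name); outside that environment it fails in the generic except branch. Swallowing litellm.Timeout with pass presumably keeps slow cold-start endpoints from flaking the suite.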