diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py new file mode 100644 index 0000000000..250383b2cf --- /dev/null +++ b/litellm/llms/sagemaker.py @@ -0,0 +1,103 @@ +import os +import json +from enum import Enum +import requests +import time +from typing import Callable +from litellm.utils import ModelResponse + +class SagemakerError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + +def completion( + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, +): + + model = model + prompt = "" + for message in messages: + if "role" in message: + if message["role"] == "user": + prompt += ( + f"{message['content']}" + ) + else: + prompt += ( + f"{message['content']}" + ) + else: + prompt += f"{message['content']}" + data = { + "prompt": prompt, + # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) + ## COMPLETION CALL + response = requests.post( + "https://api.ai21.com/studio/v1/" + model + "/complete", headers=headers, data=json.dumps(data) + ) + if "stream" in optional_params and optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + completion_response = response.json() + if "error" in completion_response: + raise SagemakerError( + message=completion_response["error"], + status_code=response.status_code, + ) + else: + try: + model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"] + except: + raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code) + + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len( + encoding.encode(prompt) + ) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"]["content"]) + ) + + model_response["created"] = time.time() + model_response["model"] = model + model_response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + return model_response + +def embedding(): + # logic for parsing in - calling - parsing out model embedding calls + pass diff --git a/litellm/main.py b/litellm/main.py index c1c5be6b0e..d1ae00fa0f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -22,6 +22,7 @@ from litellm.utils import ( from .llms import anthropic from .llms import together_ai from .llms import ai21 +from .llms import sagemaker from .llms.huggingface_restapi import HuggingfaceRestAPILLM from .llms.baseten import BasetenLLM from .llms.aleph_alpha import AlephAlphaLLM @@ -680,6 +681,32 @@ def completion( ## RESPONSE OBJECT response = model_response + elif custom_llm_provider == "sagemaker": + # boto3 reads keys from .env + model_response = sagemaker.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=ai21_key, + logging_obj=logging + ) + + # TODO: Add streaming for sagemaker + # if "stream" in optional_params and optional_params["stream"] == True: + # # don't try to access stream object, + # response = CustomStreamWrapper( + # model_response, model, custom_llm_provider="ai21", logging_obj=logging + # ) + # return response + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "ollama": endpoint = ( litellm.api_base if litellm.api_base is not None else api_base