From 3f116b25a9ac3842708313869434339566707403 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 23 Aug 2024 10:31:35 -0700
Subject: [PATCH] feat(sagemaker.py): add sagemaker messages api support

Closes https://github.com/BerriAI/litellm/issues/2641
Closes https://github.com/BerriAI/litellm/pull/5178
---
 litellm/__init__.py             |  1 +
 litellm/llms/databricks.py      | 33 ++++++++++++++-------
 litellm/llms/sagemaker.py       | 51 +++++++++++++++++++++++++++++----
 litellm/main.py                 | 10 +++++--
 litellm/tests/test_sagemaker.py | 34 ++++++++++++++++++++++
 litellm/utils.py                |  1 +
 6 files changed, 112 insertions(+), 18 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 850865a36..92624edc4 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -669,6 +669,7 @@ provider_list: List = [
     "azure_text",
     "azure_ai",
     "sagemaker",
+    "sagemaker_chat",
     "bedrock",
     "vllm",
     "nlp_cloud",
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 0c5509a71..bd529046a 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -235,23 +235,28 @@ class DatabricksChatCompletion(BaseLLM):
         api_base: Optional[str],
         endpoint_type: Literal["chat_completions", "embeddings"],
         custom_endpoint: Optional[bool],
+        headers: Optional[dict],
     ) -> Tuple[str, dict]:
-        if api_key is None:
+        if api_key is None and headers is None:
             raise DatabricksError(
                 status_code=400,
-                message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
+                message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
             )

         if api_base is None:
             raise DatabricksError(
                 status_code=400,
-                message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
+                message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
             )

-        headers = {
-            "Authorization": "Bearer {}".format(api_key),
-            "Content-Type": "application/json",
-        }
+        if headers is None:
+            headers = {
+                "Authorization": "Bearer {}".format(api_key),
+                "Content-Type": "application/json",
+            }
+        else:
+            if api_key is not None:
+                headers.update({"Authorization": "Bearer {}".format(api_key)})

         if endpoint_type == "chat_completions" and custom_endpoint is not True:
             api_base = "{}/chat/completions".format(api_base)
@@ -356,23 +361,27 @@ class DatabricksChatCompletion(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        api_key,
+        api_key: Optional[str],
         logging_obj,
         optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        headers={},
+        headers: Optional[dict] = None,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
     ):
-        custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        custom_endpoint = custom_endpoint or optional_params.pop(
+            "custom_endpoint", None
+        )
         base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
             endpoint_type="chat_completions",
             custom_endpoint=custom_endpoint,
+            headers=headers,
         )
         ## Load Config
         config = litellm.DatabricksConfig().get_config()
@@ -382,7 +391,7 @@ class DatabricksChatCompletion(BaseLLM):
             ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 optional_params[k] = v

-        stream: bool = optional_params.pop("stream", None) or False
+        stream: bool = optional_params.get("stream", None) or False
         optional_params["stream"] = stream

         data = {
@@ -565,12 +574,14 @@ class DatabricksChatCompletion(BaseLLM):
         model_response: Optional[litellm.utils.EmbeddingResponse] = None,
         client=None,
         aembedding=None,
+        headers: Optional[dict] = None,
     ) -> EmbeddingResponse:
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
             endpoint_type="embeddings",
             custom_endpoint=False,
+            headers=headers,
         )
         model = model
         data = {"model": model, "input": input, **optional_params}
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 76c3460f1..33be2efb8 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -206,17 +206,60 @@ class SagemakerLLM(BaseAWSLLM):
         print_verbose: Callable,
         encoding,
         logging_obj,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
         custom_prompt_dict={},
         hf_model_name=None,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
         acompletion: bool = False,
+        use_messages_api: Optional[bool] = None,
     ):
         # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
         credentials, aws_region_name = self._load_credentials(optional_params)
         inference_params = deepcopy(optional_params)
+        stream = inference_params.pop("stream", None)
+        model_id = optional_params.get("model_id", None)
+
+        if use_messages_api is True:
+            from litellm.llms.databricks import DatabricksChatCompletion
+
+            openai_like_chat_completions = DatabricksChatCompletion()
+            inference_params["stream"] = True if stream is True else False
+            _data = {
+                "model": model,
+                "messages": messages,
+                **inference_params,
+            }
+
+            prepared_request = self._prepare_request(
+                model=model,
+                data=_data,
+                optional_params=optional_params,
+                credentials=credentials,
+                aws_region_name=aws_region_name,
+            )
+
+            return openai_like_chat_completions.completion(
+                model=model,
+                messages=messages,
+                api_base=prepared_request.url,
+                api_key=None,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                logging_obj=logging_obj,
+                optional_params=inference_params,
+                acompletion=acompletion,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                timeout=timeout,
+                encoding=encoding,
+                headers=prepared_request.headers,
+                custom_endpoint=True,
+                custom_llm_provider="sagemaker_chat",
+            )

         ## Load Config
         config = litellm.SagemakerConfig.get_config()
@@ -259,8 +302,6 @@ class SagemakerLLM(BaseAWSLLM):
                 hf_model_name or model
             )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
             prompt = prompt_factory(model=hf_model_name, messages=messages)
-        stream = inference_params.pop("stream", None)
-        model_id = optional_params.get("model_id", None)

         if stream is True:
             data = {"inputs": prompt, "parameters": inference_params, "stream": True}
@@ -275,7 +316,7 @@ class SagemakerLLM(BaseAWSLLM):
                 # Add model_id as InferenceComponentName header
                 # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
                 prepared_request.headers.update(
-                    {"X-Amzn-SageMaker-Inference-Componen": model_id}
+                    {"X-Amzn-SageMaker-Inference-Component": model_id}
                 )

             if acompletion is True:
@@ -338,7 +379,7 @@ class SagemakerLLM(BaseAWSLLM):
             )

         # Async completion
-        if acompletion == True:
+        if acompletion is True:
             return self.async_completion(
                 prepared_request=prepared_request,
                 model_response=model_response,
@@ -354,7 +395,7 @@ class SagemakerLLM(BaseAWSLLM):
             # Add model_id as InferenceComponentName header
             # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
             prepared_request.headers.update(
-                {"X-Amzn-SageMaker-Inference-Componen": model_id}
+                {"X-Amzn-SageMaker-Inference-Component": model_id}
             )

         ## LOGGING
diff --git a/litellm/main.py b/litellm/main.py
index 8104bfd86..a2275cbe5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -381,6 +381,7 @@ async def acompletion(
         or custom_llm_provider == "vertex_ai_beta"
         or custom_llm_provider == "gemini"
         or custom_llm_provider == "sagemaker"
+        or custom_llm_provider == "sagemaker_chat"
         or custom_llm_provider == "anthropic"
         or custom_llm_provider == "predibase"
         or custom_llm_provider == "bedrock"
@@ -945,7 +946,6 @@ def completion(
         text_completion=kwargs.get("text_completion"),
         azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
         user_continue_message=kwargs.get("user_continue_message"),
-
     )
     logging.update_environment_variables(
         model=model,
@@ -2247,7 +2247,10 @@ def completion(

             ## RESPONSE OBJECT
             response = model_response
-        elif custom_llm_provider == "sagemaker":
+        elif (
+            custom_llm_provider == "sagemaker"
+            or custom_llm_provider == "sagemaker_chat"
+        ):
             # boto3 reads keys from .env
             model_response = sagemaker_llm.completion(
                 model=model,
@@ -2262,6 +2265,9 @@ def completion(
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
+                use_messages_api=(
+                    True if custom_llm_provider == "sagemaker_chat" else False
+                ),
             )
             if optional_params.get("stream", False):
                 ## LOGGING
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
index 3d0d0d0cb..147c06c57 100644
--- a/litellm/tests/test_sagemaker.py
+++ b/litellm/tests/test_sagemaker.py
@@ -84,6 +84,40 @@ async def test_completion_sagemaker(sync_mode):
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)
+async def test_completion_sagemaker_messages_api(sync_mode):
+    try:
+        litellm.set_verbose = True
+        verbose_logger.setLevel(logging.DEBUG)
+        print("testing sagemaker")
+        if sync_mode is True:
+            resp = litellm.completion(
+                model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+                messages=[
+                    {"role": "user", "content": "hi"},
+                ],
+                temperature=0.2,
+                max_tokens=80,
+            )
+            print(resp)
+        else:
+            resp = await litellm.acompletion(
+                model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+                messages=[
+                    {"role": "user", "content": "hi"},
+                ],
+                temperature=0.2,
+                max_tokens=80,
+            )
+            print(resp)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
 async def test_completion_sagemaker_stream(sync_mode):
diff --git a/litellm/utils.py b/litellm/utils.py
index ec85d6ca1..92514a6d7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10611,6 +10611,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "vertex_ai_beta"
                 or self.custom_llm_provider == "sagemaker"
+                or self.custom_llm_provider == "sagemaker_chat"
                 or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
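
Usage sketch of the new "sagemaker_chat/" route, mirroring the test added above; the
endpoint name is a placeholder for an already-deployed SageMaker endpoint, and AWS
credentials are resolved the same way as the existing "sagemaker" provider (boto3/env).

    import litellm

    # "sagemaker_chat/<endpoint-name>" sends an OpenAI-style messages payload to the
    # SageMaker endpoint (via the messages-API path this patch adds), instead of the
    # prompt-string path used by plain "sagemaker/". The endpoint name is illustrative.
    response = litellm.completion(
        model="sagemaker_chat/<your-sagemaker-endpoint-name>",
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.2,
        max_tokens=80,
    )
    print(response.choices[0].message.content)

    # The async variant takes the same route:
    # response = await litellm.acompletion(model="sagemaker_chat/<your-sagemaker-endpoint-name>", ...)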