diff --git a/litellm/__init__.py b/litellm/__init__.py
index aedf421391..67170c68d0 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -670,6 +670,7 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
+from .llms.bedrock_httpx import AmazonCohereChatConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index c6b0327e6b..d3062b5ed8 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -7,7 +7,7 @@ import json
 from enum import Enum
 import requests, copy  # type: ignore
 import time
-from typing import Callable, Optional, List, Literal, Union
+from typing import Callable, Optional, List, Literal, Union, Any, TypedDict, Tuple
 from litellm.utils import (
     ModelResponse,
     Usage,
@@ -18,11 +18,110 @@ from litellm.utils import (
     get_secret,
 )
 import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
-from .bedrock import BedrockError
+from .bedrock import BedrockError, convert_messages_to_prompt
+from litellm.types.llms.bedrock import *
+
+
+class AmazonCohereChatConfig:
+    """
+    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+    """
+
+    documents: Optional[List[Document]] = None
+    search_queries_only: Optional[bool] = None
+    preamble: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    p: Optional[float] = None
+    k: Optional[float] = None
+    prompt_truncation: Optional[str] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    seed: Optional[int] = None
+    return_prompt: Optional[bool] = None
+    stop_sequences: Optional[List[str]] = None
+    raw_prompting: Optional[bool] = None
+
+    def __init__(
+        self,
+        documents: Optional[List[Document]] = None,
+        search_queries_only: Optional[bool] = None,
+        preamble: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        p: Optional[float] = None,
+        k: Optional[float] = None,
+        prompt_truncation: Optional[str] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        return_prompt: Optional[bool] = None,
+        stop_sequences: Optional[str] = None,
+        raw_prompting: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self) -> List[str]:
+        return [
+            "max_tokens",
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "seed",
+            "stop",
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
"max_tokens": + optional_params["max_tokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + value = [value] + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["p"] = value + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if "seed": + optional_params["seed"] = value + return optional_params class BedrockLLM(BaseLLM): @@ -47,6 +146,48 @@ class BedrockLLM(BaseLLM): def __init__(self) -> None: super().__init__() + def convert_messages_to_prompt( + self, model, messages, provider, custom_prompt_dict + ) -> Tuple[str, Optional[list]]: + # handle anthropic prompts and amazon titan prompts + prompt = "" + chat_history: Optional[list] = None + if provider == "anthropic" or provider == "amazon": + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "mistral": + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "meta": + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "cohere": + prompt, chat_history = cohere_message_pt(messages=messages) + else: + prompt = "" + for message in messages: + if "role" in message: + if message["role"] == "user": + prompt += f"{message['content']}" + else: + prompt += f"{message['content']}" + else: + prompt += f"{message['content']}" + return prompt, chat_history # type: ignore + def get_credentials( self, aws_access_key_id: Optional[str] = None, @@ -114,11 +255,168 @@ class BedrockLLM(BaseLLM): return session.get_credentials() - def completion(self, *args, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]: - ## get credentials - ## generate signature - ## make request - return super().completion(*args, **kwargs) + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params=None, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[HTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + try: + import boto3 + + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url = "" + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + endpoint_url = f"{endpoint_url}/model/{model}/invoke" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + provider = model.split(".")[0] + prompt, chat_history = self.convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) + inference_params = copy.deepcopy(optional_params) + stream = inference_params.pop("stream", False) + + if provider == "cohere": + if model.startswith("cohere.command-r"): + ## LOAD CONFIG + config = litellm.AmazonCohereChatConfig().get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + if optional_params.get("stream", False) == True: + inference_params["stream"] = ( + True # cohere requires stream = True in inference params + ) + + _data = {"message": prompt, **inference_params} + if chat_history is not None: + _data["chat_history"] = chat_history + data = json.dumps(_data) + else: + ## LOAD CONFIG + config = litellm.AmazonCohereConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + if optional_params.get("stream", False) == True: + inference_params["stream"] = ( + True # cohere requires stream = True in inference params + ) + data = json.dumps({"prompt": prompt, **inference_params}) + else: 
+            raise Exception("UNSUPPORTED PROVIDER")
+
+        ## COMPLETION CALL
+        headers = {"Content-Type": "application/json"}
+        request = AWSRequest(
+            method="POST", url=endpoint_url, data=data, headers=headers
+        )
+        sigv4.add_auth(request)
+        prepped = request.prepare()
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            self.client = HTTPHandler(**_params)  # type: ignore
+        else:
+            self.client = client
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": prepped.url,
+                "headers": prepped.headers,
+            },
+        )
+
+        response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=response.text)
+
+        return response
 
     def embedding(self, *args, **kwargs):
         return super().embedding(*args, **kwargs)
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 7c7d4938a4..529ba3b390 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -58,16 +58,25 @@ class AsyncHTTPHandler:
 
 class HTTPHandler:
     def __init__(
-        self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+        self,
+        timeout: Optional[httpx.Timeout] = None,
+        concurrent_limit=1000,
+        client: Optional[httpx.Client] = None,
     ):
-        # Create a client with a connection pool
-        self.client = httpx.Client(
-            timeout=timeout,
-            limits=httpx.Limits(
-                max_connections=concurrent_limit,
-                max_keepalive_connections=concurrent_limit,
-            ),
-        )
+        if timeout is None:
+            timeout = _DEFAULT_TIMEOUT
+
+        if client is None:
+            # Create a client with a connection pool
+            self.client = httpx.Client(
+                timeout=timeout,
+                limits=httpx.Limits(
+                    max_connections=concurrent_limit,
+                    max_keepalive_connections=concurrent_limit,
+                ),
+            )
+        else:
+            self.client = client
 
     def close(self):
         # Close the client when you're done with it
diff --git a/litellm/main.py b/litellm/main.py
index 8be71de0b7..d2f3939fde 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1922,20 +1922,37 @@ def completion(
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = bedrock.completion(
-            model=model,
-            messages=messages,
-            custom_prompt_dict=litellm.custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            logging_obj=logging,
-            extra_headers=extra_headers,
-            timeout=timeout,
-        )
+
+        if "cohere" in model:
+            response = bedrock_chat_completion.completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                extra_headers=extra_headers,
+                timeout=timeout,
+            )
+        else:
+            response = bedrock.completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                extra_headers=extra_headers,
+                timeout=timeout,
+            )
 
         if (
             "stream" in optional_params
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 214dc105b1..0cf6dda835 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2585,6 +2585,7 @@ def test_completion_chat_sagemaker_mistral():
 
 
 def test_completion_bedrock_command_r():
+    litellm.set_verbose = True
     response = completion(
         model="bedrock/cohere.command-r-plus-v1:0",
         messages=[{"role": "user", "content": "Hey! how's it going?"}],
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
new file mode 100644
index 0000000000..87ef6fd3cc
--- /dev/null
+++ b/litellm/types/llms/bedrock.py
@@ -0,0 +1,6 @@
+from typing import TypedDict
+
+
+class Document(TypedDict):
+    title: str
+    snippet: str
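
For reference, a minimal usage sketch of the new Bedrock Cohere chat path (not part of the patch itself): the model name, message, and set_verbose flag come from the test added in litellm/tests/test_completion.py; max_tokens, temperature, top_p, and stop are standard OpenAI-style kwargs that AmazonCohereChatConfig.map_openai_params translates to Cohere's max_tokens, temperature, p, and stop_sequences. AWS credentials are assumed to come from the environment (boto3 reads keys from .env, per the comment in main.py), or they can be passed as aws_access_key_id / aws_secret_access_key / aws_region_name, which BedrockLLM.completion pops from optional_params.

import litellm

litellm.set_verbose = True  # same flag the test enables, for request/response logging

# "cohere" in the model name routes this call through bedrock_chat_completion.completion()
response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    max_tokens=100,    # -> "max_tokens"
    temperature=0.7,   # -> "temperature"
    top_p=0.9,         # -> "p"
    stop=["User:"],    # -> "stop_sequences" (a bare string is wrapped in a list)
)
print(response)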