feat(bedrock_httpx.py): working cohere command r async calls

Krrish Dholakia 2024-05-11 15:04:38 -07:00
parent 926b86af87
commit 5185580e3d
6 changed files with 364 additions and 32 deletions


@@ -670,6 +670,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock import (
    AmazonTitanConfig,
    AmazonAI21Config,


@@ -7,7 +7,7 @@ import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union, Any, TypedDict, Tuple
from litellm.utils import (
    ModelResponse,
    Usage,
@@ -18,11 +18,110 @@ from litellm.utils import (
    get_secret,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx  # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt
from litellm.types.llms.bedrock import *
class AmazonCohereChatConfig:
    """
    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
    """

    documents: Optional[List[Document]] = None
    search_queries_only: Optional[bool] = None
    preamble: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    p: Optional[float] = None
    k: Optional[float] = None
    prompt_truncation: Optional[str] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    seed: Optional[int] = None
    return_prompt: Optional[bool] = None
    stop_sequences: Optional[List[str]] = None
    raw_prompting: Optional[bool] = None
    def __init__(
        self,
        documents: Optional[List[Document]] = None,
        search_queries_only: Optional[bool] = None,
        preamble: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        p: Optional[float] = None,
        k: Optional[float] = None,
        prompt_truncation: Optional[str] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        return_prompt: Optional[bool] = None,
        stop_sequences: Optional[List[str]] = None,  # match the class attribute's List[str] type
        raw_prompting: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
    def get_supported_openai_params(self) -> List[str]:
        return [
            "max_tokens",
            "stream",
            "stop",
            "temperature",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
        ]
    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                if isinstance(value, str):
                    value = [value]
                optional_params["stop_sequences"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["p"] = value
            if param == "frequency_penalty":
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "seed":  # bug fix: was `if "seed":`, which is always truthy
                optional_params["seed"] = value
        return optional_params
class BedrockLLM(BaseLLM):
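For illustration, a minimal usage sketch of the new config class (an editor's example, not part of the commit; it assumes the AmazonCohereChatConfig export added in the first hunk):

    # Sketch: OpenAI-style params mapped onto Cohere Command R's Bedrock schema.
    # Values set on the class (e.g. AmazonCohereChatConfig(max_tokens=100)) are
    # later surfaced by get_config() and merged into the request body.
    from litellm import AmazonCohereChatConfig

    mapped = AmazonCohereChatConfig().map_openai_params(
        non_default_params={"max_tokens": 256, "top_p": 0.9, "stop": "END"},
        optional_params={},
    )
    # mapped == {"max_tokens": 256, "p": 0.9, "stop_sequences": ["END"]}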
@@ -47,6 +146,48 @@ class BedrockLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
    def convert_messages_to_prompt(
        self, model, messages, provider, custom_prompt_dict
    ) -> Tuple[str, Optional[list]]:
        # handle anthropic prompts and amazon titan prompts
        prompt = ""
        chat_history: Optional[list] = None
        if provider == "anthropic" or provider == "amazon":
            if model in custom_prompt_dict:
                # check if the model has a registered custom prompt
                model_prompt_details = custom_prompt_dict[model]
                prompt = custom_prompt(
                    role_dict=model_prompt_details["roles"],
                    initial_prompt_value=model_prompt_details["initial_prompt_value"],
                    final_prompt_value=model_prompt_details["final_prompt_value"],
                    messages=messages,
                )
            else:
                prompt = prompt_factory(
                    model=model, messages=messages, custom_llm_provider="bedrock"
                )
        elif provider == "mistral":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "meta":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "cohere":
            prompt, chat_history = cohere_message_pt(messages=messages)
        else:
            prompt = ""
            for message in messages:
                if "role" in message:
                    if message["role"] == "user":
                        prompt += f"{message['content']}"
                    else:
                        prompt += f"{message['content']}"
                else:
                    prompt += f"{message['content']}"
        return prompt, chat_history  # type: ignore
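    # Editor's sketch (not in the diff): for the cohere branch, the assumed
    # behavior of cohere_message_pt is to split the latest user message from
    # the prior turns, per Cohere's chat schema, e.g.
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"},
    #    {"role": "user", "content": "What's the weather?"}]
    #   -> prompt == "What's the weather?", chat_history == the earlier turns.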
    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,

@@ -114,11 +255,168 @@ class BedrockLLM(BaseLLM):
        return session.get_credentials()
    def completion(
        self,
        model: str,
        messages: list,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        optional_params: dict,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params=None,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[HTTPHandler] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
        )

        ### SET RUNTIME ENDPOINT ###
        endpoint_url = ""
        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
        if aws_bedrock_runtime_endpoint is not None and isinstance(
            aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = aws_bedrock_runtime_endpoint
        elif env_aws_bedrock_runtime_endpoint and isinstance(
            env_aws_bedrock_runtime_endpoint, str
        ):
            endpoint_url = env_aws_bedrock_runtime_endpoint
        else:
            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

        endpoint_url = f"{endpoint_url}/model/{model}/invoke"

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

        provider = model.split(".")[0]
        prompt, chat_history = self.convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
        inference_params = copy.deepcopy(optional_params)
        stream = inference_params.pop("stream", False)

        if provider == "cohere":
            if model.startswith("cohere.command-r"):
                ## LOAD CONFIG
                config = litellm.AmazonCohereChatConfig().get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > amazon_cohere_chat_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                if optional_params.get("stream", False) == True:
                    inference_params["stream"] = (
                        True  # cohere requires stream = True in inference params
                    )
                _data = {"message": prompt, **inference_params}
                if chat_history is not None:
                    _data["chat_history"] = chat_history
                data = json.dumps(_data)
            else:
                ## LOAD CONFIG
                config = litellm.AmazonCohereConfig.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > amazon_cohere_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                if optional_params.get("stream", False) == True:
                    inference_params["stream"] = (
                        True  # cohere requires stream = True in inference params
                    )
                data = json.dumps({"prompt": prompt, **inference_params})
        else:
            raise Exception("UNSUPPORTED PROVIDER")
        ## COMPLETION CALL

        headers = {"Content-Type": "application/json"}
        request = AWSRequest(
            method="POST", url=endpoint_url, data=data, headers=headers
        )
        sigv4.add_auth(request)
        prepped = request.prepare()

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            self.client = HTTPHandler(**_params)  # type: ignore
        else:
            self.client = client

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepped.url,
                "headers": prepped.headers,
            },
        )

        response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)

        return response
    def embedding(self, *args, **kwargs):
        return super().embedding(*args, **kwargs)
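Putting the hunk together: for cohere.command-r-plus-v1:0 in us-west-2, completion() above signs and POSTs a request shaped roughly like this (an editor's sketch; the field values are illustrative, not output from the code):

    # Sketch of the assembled invoke request (illustrative values).
    url = "https://bedrock-runtime.us-west-2.amazonaws.com/model/cohere.command-r-plus-v1:0/invoke"
    body = {
        "message": "Hey! how's it going?",  # prompt from cohere_message_pt
        # "chat_history": [...],            # only set when prior turns exist
        # plus any mapped params, e.g. "max_tokens", "p", "stop_sequences"
    }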


@@ -58,16 +58,25 @@ class AsyncHTTPHandler:
class HTTPHandler:
    def __init__(
        self,
        timeout: Optional[httpx.Timeout] = None,
        concurrent_limit=1000,
        client: Optional[httpx.Client] = None,
    ):
        if timeout is None:
            timeout = _DEFAULT_TIMEOUT

        if client is None:
            # Create a client with a connection pool
            self.client = httpx.Client(
                timeout=timeout,
                limits=httpx.Limits(
                    max_connections=concurrent_limit,
                    max_keepalive_connections=concurrent_limit,
                ),
            )
        else:
            self.client = client

    def close(self):
        # Close the client when you're done with it
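The new client parameter makes HTTPHandler injectable; a brief usage sketch (an editor's example; the custom settings are hypothetical):

    # Sketch: inject a preconfigured httpx.Client instead of the pooled default.
    import httpx
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    handler = HTTPHandler(client=httpx.Client(timeout=30.0))  # hypothetical settings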


@@ -1922,20 +1922,37 @@ def completion(
    elif custom_llm_provider == "bedrock":
        # boto3 reads keys from .env
        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

        if "cohere" in model:
            response = bedrock_chat_completion.completion(
                model=model,
                messages=messages,
                custom_prompt_dict=litellm.custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,
                logging_obj=logging,
                extra_headers=extra_headers,
                timeout=timeout,
            )
        else:
            response = bedrock.completion(
                model=model,
                messages=messages,
                custom_prompt_dict=litellm.custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,
                logging_obj=logging,
                extra_headers=extra_headers,
                timeout=timeout,
            )

        if (
            "stream" in optional_params


@@ -2585,6 +2585,7 @@ def test_completion_chat_sagemaker_mistral():
def test_completion_bedrock_command_r():
    litellm.set_verbose = True
    response = completion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
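The commit title advertises async calls; a sketch of that path through litellm.acompletion (an editor's example; assumes AWS credentials are configured in the environment):

    # Sketch: async Command R call via litellm.acompletion.
    import asyncio
    import litellm

    async def main():
        response = await litellm.acompletion(
            model="bedrock/cohere.command-r-plus-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
        )
        print(response)

    asyncio.run(main())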


@@ -0,0 +1,6 @@
from typing import TypedDict


class Document(TypedDict):
    title: str
    snippet: str
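A quick sketch of the new type in use (an editor's example; AmazonCohereChatConfig declares a documents field of this type, but this diff does not show it being wired into the request body):

    # Sketch: a Document matching the TypedDict above.
    doc: Document = {"title": "Weather", "snippet": "It is sunny today."}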