fix(bedrock.py): adding support for cohere embeddings

This commit is contained in:
Krrish Dholakia 2023-12-06 13:24:49 -08:00
parent cf6ecc03a5
commit d962d5d4c0
4 changed files with 85 additions and 38 deletions

View file

@ -339,7 +339,7 @@ cohere_embedding_models: List = [
"embed-english-light-v2.0", "embed-english-light-v2.0",
"embed-multilingual-v2.0", "embed-multilingual-v2.0",
] ]
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"] bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models

View file

@ -2,7 +2,7 @@ import json, copy, types
import os import os
from enum import Enum from enum import Enum
import time import time
from typing import Callable, Optional from typing import Callable, Optional, Any
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -205,15 +205,25 @@ class AmazonLlamaConfig():
def init_bedrock_client( def init_bedrock_client(
region_name = None, region_name = None,
aws_access_key_id = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key = None, aws_secret_access_key: Optional[str] = None,
aws_region_name=None, aws_region_name: Optional[str] =None,
aws_bedrock_runtime_endpoint=None, aws_bedrock_runtime_endpoint: Optional[str]=None,
): ):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME") litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION") standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith('os.environ/'):
params_to_check[i] = get_secret(param)
# Assign updated values back to parameters
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
if region_name: if region_name:
pass pass
elif aws_region_name: elif aws_region_name:
@ -533,36 +543,55 @@ def completion(
def _embedding_func_single( def _embedding_func_single(
model: str, model: str,
input: str, input: str,
client: Any,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
logging_obj=None,
): ):
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them ## FORMAT EMBEDDING INPUT ##
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) provider = model.split(".")[0]
aws_access_key_id = optional_params.pop("aws_access_key_id", None) inference_params = copy.deepcopy(optional_params)
aws_region_name = optional_params.pop("aws_region_name", None) if provider == "amazon":
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop(
"aws_bedrock_client",
# only pass variables that are not None
init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
),
)
input = input.replace(os.linesep, " ") input = input.replace(os.linesep, " ")
body = json.dumps({"inputText": input}) data = {"inputText": input, **inference_params}
# data = json.dumps(data)
elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params}
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body},
modelId={model},
accept="*/*",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=input,
api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model,
"texts": input}, "request_str": request_str},
)
try: try:
response = client.invoke_model( response = client.invoke_model(
body=body, body=body,
modelId=model, modelId="cohere.embed-multilingual-v3",
accept="application/json", accept="*/*",
contentType="application/json", contentType="application/json",
) )
response_body = json.loads(response.get("body").read()) response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=response_body,
)
if provider == "cohere":
return response_body.get("embeddings")
elif provider == "amazon":
return response_body.get("embedding") return response_body.get("embedding")
except Exception as e: except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500) raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
@ -576,17 +605,21 @@ def embedding(
optional_params=None, optional_params=None,
encoding=None, encoding=None,
): ):
### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
## LOGGING # use passed in BedrockRuntime.Client if provided, otherwise create a new one
logging_obj.pre_call( client = init_bedrock_client(
input=input, aws_access_key_id=aws_access_key_id,
api_key=api_key, aws_secret_access_key=aws_secret_access_key,
additional_args={"complete_input_dict": {"model": model, aws_region_name=aws_region_name,
"texts": input}},
) )
## Embedding Call ## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params) for i in input] embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary ## Populate OpenAI compliant dictionary

View file

@ -843,6 +843,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
data["user"] = user_api_key_dict.user_id data["user"] = user_api_key_dict.user_id
if "metadata" in data: if "metadata" in data:
print(f'received metadata: {data["metadata"]}')
data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = request.headers data["metadata"]["headers"] = request.headers
else: else:

View file

@ -151,7 +151,7 @@ def test_cohere_embedding3():
# test_cohere_embedding3() # test_cohere_embedding3()
def test_bedrock_embedding(): def test_bedrock_embedding_titan():
try: try:
response = embedding( response = embedding(
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data", model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
@ -162,6 +162,19 @@ def test_bedrock_embedding():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding() # test_bedrock_embedding()
def test_bedrock_embedding_cohere():
try:
# litellm.set_verbose=True
response = embedding(
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
aws_region_name="os.environ/AWS_REGION_NAME_2"
)
# print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_bedrock_embedding_cohere()
# comment out hf tests - since hf endpoints are unstable # comment out hf tests - since hf endpoints are unstable
def test_hf_embedding(): def test_hf_embedding():
try: try: