forked from phoenix/litellm-mirror
fix(bedrock.py): adding support for cohere embeddings
This commit is contained in:
parent
cf6ecc03a5
commit
d962d5d4c0
4 changed files with 85 additions and 38 deletions
|
@ -339,7 +339,7 @@ cohere_embedding_models: List = [
|
||||||
"embed-english-light-v2.0",
|
"embed-english-light-v2.0",
|
||||||
"embed-multilingual-v2.0",
|
"embed-multilingual-v2.0",
|
||||||
]
|
]
|
||||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"]
|
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
|
||||||
|
|
||||||
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json, copy, types
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, Any
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import ModelResponse, get_secret, Usage
|
from litellm.utils import ModelResponse, get_secret, Usage
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
@ -205,15 +205,25 @@ class AmazonLlamaConfig():
|
||||||
|
|
||||||
def init_bedrock_client(
|
def init_bedrock_client(
|
||||||
region_name = None,
|
region_name = None,
|
||||||
aws_access_key_id = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
aws_region_name=None,
|
aws_region_name: Optional[str] =None,
|
||||||
aws_bedrock_runtime_endpoint=None,
|
aws_bedrock_runtime_endpoint: Optional[str]=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME")
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
standard_aws_region_name = get_secret("AWS_REGION")
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
# Define the list of parameters to check
|
||||||
|
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith('os.environ/'):
|
||||||
|
params_to_check[i] = get_secret(param)
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
|
||||||
if region_name:
|
if region_name:
|
||||||
pass
|
pass
|
||||||
elif aws_region_name:
|
elif aws_region_name:
|
||||||
|
@ -533,37 +543,56 @@ def completion(
|
||||||
def _embedding_func_single(
|
def _embedding_func_single(
|
||||||
model: str,
|
model: str,
|
||||||
input: str,
|
input: str,
|
||||||
|
client: Any,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
encoding=None,
|
encoding=None,
|
||||||
|
logging_obj=None,
|
||||||
):
|
):
|
||||||
# logic for parsing in - calling - parsing out model embedding calls
|
# logic for parsing in - calling - parsing out model embedding calls
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
## FORMAT EMBEDDING INPUT ##
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
provider = model.split(".")[0]
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
inference_params = copy.deepcopy(optional_params)
|
||||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
if provider == "amazon":
|
||||||
|
input = input.replace(os.linesep, " ")
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
data = {"inputText": input, **inference_params}
|
||||||
client = optional_params.pop(
|
# data = json.dumps(data)
|
||||||
"aws_bedrock_client",
|
elif provider == "cohere":
|
||||||
# only pass variables that are not None
|
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||||
init_bedrock_client(
|
data = {"texts": [input], **inference_params}
|
||||||
aws_access_key_id=aws_access_key_id,
|
body = json.dumps(data).encode("utf-8")
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
## LOGGING
|
||||||
aws_region_name=aws_region_name,
|
request_str = f"""
|
||||||
),
|
response = client.invoke_model(
|
||||||
|
body={body},
|
||||||
|
modelId={model},
|
||||||
|
accept="*/*",
|
||||||
|
contentType="application/json",
|
||||||
|
)""" # type: ignore
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=input,
|
||||||
|
api_key="", # boto3 is used for init.
|
||||||
|
additional_args={"complete_input_dict": {"model": model,
|
||||||
|
"texts": input}, "request_str": request_str},
|
||||||
)
|
)
|
||||||
|
|
||||||
input = input.replace(os.linesep, " ")
|
|
||||||
body = json.dumps({"inputText": input})
|
|
||||||
try:
|
try:
|
||||||
response = client.invoke_model(
|
response = client.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=model,
|
modelId="cohere.embed-multilingual-v3",
|
||||||
accept="application/json",
|
accept="*/*",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
response_body = json.loads(response.get("body").read())
|
response_body = json.loads(response.get("body").read())
|
||||||
return response_body.get("embedding")
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key="",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=response_body,
|
||||||
|
)
|
||||||
|
if provider == "cohere":
|
||||||
|
return response_body.get("embeddings")
|
||||||
|
elif provider == "amazon":
|
||||||
|
return response_body.get("embedding")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
||||||
|
|
||||||
|
@ -576,17 +605,21 @@ def embedding(
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
encoding=None,
|
encoding=None,
|
||||||
):
|
):
|
||||||
|
### BOTO3 INIT ###
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
|
||||||
## LOGGING
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
logging_obj.pre_call(
|
client = init_bedrock_client(
|
||||||
input=input,
|
aws_access_key_id=aws_access_key_id,
|
||||||
api_key=api_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
additional_args={"complete_input_dict": {"model": model,
|
aws_region_name=aws_region_name,
|
||||||
"texts": input}},
|
)
|
||||||
)
|
|
||||||
|
|
||||||
## Embedding Call
|
## Embedding Call
|
||||||
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
|
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
|
||||||
|
|
||||||
|
|
||||||
## Populate OpenAI compliant dictionary
|
## Populate OpenAI compliant dictionary
|
||||||
|
|
|
@ -843,6 +843,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
||||||
data["user"] = user_api_key_dict.user_id
|
data["user"] = user_api_key_dict.user_id
|
||||||
|
|
||||||
if "metadata" in data:
|
if "metadata" in data:
|
||||||
|
print(f'received metadata: {data["metadata"]}')
|
||||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
data["metadata"]["headers"] = request.headers
|
data["metadata"]["headers"] = request.headers
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -151,7 +151,7 @@ def test_cohere_embedding3():
|
||||||
|
|
||||||
# test_cohere_embedding3()
|
# test_cohere_embedding3()
|
||||||
|
|
||||||
def test_bedrock_embedding():
|
def test_bedrock_embedding_titan():
|
||||||
try:
|
try:
|
||||||
response = embedding(
|
response = embedding(
|
||||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
||||||
|
@ -162,6 +162,19 @@ def test_bedrock_embedding():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# test_bedrock_embedding()
|
# test_bedrock_embedding()
|
||||||
|
|
||||||
|
def test_bedrock_embedding_cohere():
|
||||||
|
try:
|
||||||
|
# litellm.set_verbose=True
|
||||||
|
response = embedding(
|
||||||
|
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
|
||||||
|
aws_region_name="os.environ/AWS_REGION_NAME_2"
|
||||||
|
)
|
||||||
|
# print(f"response:", response)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
test_bedrock_embedding_cohere()
|
||||||
|
|
||||||
# comment out hf tests - since hf endpoints are unstable
|
# comment out hf tests - since hf endpoints are unstable
|
||||||
def test_hf_embedding():
|
def test_hf_embedding():
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue