Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00

refactor: add black formatting

parent b87d630b0a · commit 4905929de3

156 changed files with 19723 additions and 10869 deletions
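Black is an opinionated Python formatter, so every hunk below is a mechanical rewrite with no behavior change: string quotes normalized to double quotes, spaces added around "=" in annotated defaults, long calls exploded one argument per line, and trailing commas added. As a small illustration, black's Python API reproduces one of the rewrites in this diff (this assumes black is installed; the exact version and invocation used for the commit are not recorded):

    # Hypothetical reproduction of one rewrite from this diff.
    import black

    old = "maxTokenCount: Optional[int]=None\n"
    print(black.format_str(old, mode=black.Mode()), end="")
    # prints: maxTokenCount: Optional[int] = None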
@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
 import httpx


 class BedrockError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock")
+        self.request = httpx.Request(
+            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
+        )
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs


-class AmazonTitanConfig():
+class AmazonTitanConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@@ -29,29 +33,44 @@ class AmazonTitanConfig():
     - `temperature` (float) temperature for model,
     - `topP` (int) top p for model
     """
-    maxTokenCount: Optional[int]=None
-    stopSequences: Optional[list]=None
-    temperature: Optional[float]=None
-    topP: Optional[int]=None
-
-    def __init__(self,
-                 maxTokenCount: Optional[int]=None,
-                 stopSequences: Optional[list]=None,
-                 temperature: Optional[float]=None,
-                 topP: Optional[int]=None) -> None:
+    maxTokenCount: Optional[int] = None
+    stopSequences: Optional[list] = None
+    temperature: Optional[float] = None
+    topP: Optional[int] = None
+
+    def __init__(
+        self,
+        maxTokenCount: Optional[int] = None,
+        stopSequences: Optional[list] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[int] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


-class AmazonAnthropicConfig():
+class AmazonAnthropicConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
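All of these provider config classes share one pattern: constructor arguments are written onto the class itself via setattr(self.__class__, ...), and get_config() reads back every explicitly-set, non-callable, non-dunder class attribute so completion() can merge them into the request (caller-supplied params take precedence). A condensed, runnable sketch of the pattern — the class name here is illustrative, not litellm's:

    from typing import Optional
    import types


    class TitanConfigSketch:
        maxTokenCount: Optional[int] = None
        temperature: Optional[float] = None

        def __init__(self, maxTokenCount=None, temperature=None) -> None:
            # mirrors the litellm pattern: store non-None args as *class* attributes
            for key, value in locals().items():
                if key != "self" and value is not None:
                    setattr(self.__class__, key, value)

        @classmethod
        def get_config(cls):
            # return explicitly-set, non-callable, non-dunder class attributes
            return {
                k: v
                for k, v in cls.__dict__.items()
                if not k.startswith("__")
                and not isinstance(
                    v,
                    (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod),
                )
                and v is not None
            }


    TitanConfigSketch(temperature=0.2)     # mutates class-level state
    print(TitanConfigSketch.get_config())  # {'temperature': 0.2}

Because setattr targets self.__class__, configuration set this way is shared process-wide rather than per-instance.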
@@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
     - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
     - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
     """
-    max_tokens_to_sample: Optional[int]=litellm.max_tokens
-    stop_sequences: Optional[list]=None
-    temperature: Optional[float]=None
-    top_k: Optional[int]=None
-    top_p: Optional[int]=None
-    anthropic_version: Optional[str]=None
-
-    def __init__(self,
-                 max_tokens_to_sample: Optional[int]=None,
-                 stop_sequences: Optional[list]=None,
-                 temperature: Optional[float]=None,
-                 top_k: Optional[int]=None,
-                 top_p: Optional[int]=None,
-                 anthropic_version: Optional[str]=None) -> None:
+    max_tokens_to_sample: Optional[int] = litellm.max_tokens
+    stop_sequences: Optional[list] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[int] = None
+    anthropic_version: Optional[str] = None
+
+    def __init__(
+        self,
+        max_tokens_to_sample: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        anthropic_version: Optional[str] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


-class AmazonCohereConfig():
+class AmazonCohereConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@@ -100,79 +134,110 @@ class AmazonCohereConfig():
     - `temperature` (float) model temperature,
     - `return_likelihood` (string) n/a
     """
-    max_tokens: Optional[int]=None
-    temperature: Optional[float]=None
-    return_likelihood: Optional[str]=None
-
-    def __init__(self,
-                 max_tokens: Optional[int]=None,
-                 temperature: Optional[float]=None,
-                 return_likelihood: Optional[str]=None) -> None:
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    return_likelihood: Optional[str] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        return_likelihood: Optional[str] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


-class AmazonAI21Config():
+class AmazonAI21Config:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

     Supported Params for the Amazon / AI21 models:

     - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

     - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

     - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

     - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

     - `frequencyPenalty` (object): Placeholder for frequency penalty object.

     - `presencePenalty` (object): Placeholder for presence penalty object.

     - `countPenalty` (object): Placeholder for count penalty object.
     """
-    maxTokens: Optional[int]=None
-    temperature: Optional[float]=None
-    topP: Optional[float]=None
-    stopSequences: Optional[list]=None
-    frequencePenalty: Optional[dict]=None
-    presencePenalty: Optional[dict]=None
-    countPenalty: Optional[dict]=None
-
-    def __init__(self,
-                 maxTokens: Optional[int]=None,
-                 temperature: Optional[float]=None,
-                 topP: Optional[float]=None,
-                 stopSequences: Optional[list]=None,
-                 frequencePenalty: Optional[dict]=None,
-                 presencePenalty: Optional[dict]=None,
-                 countPenalty: Optional[dict]=None) -> None:
+    maxTokens: Optional[int] = None
+    temperature: Optional[float] = None
+    topP: Optional[float] = None
+    stopSequences: Optional[list] = None
+    frequencePenalty: Optional[dict] = None
+    presencePenalty: Optional[dict] = None
+    countPenalty: Optional[dict] = None
+
+    def __init__(
+        self,
+        maxTokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[float] = None,
+        stopSequences: Optional[list] = None,
+        frequencePenalty: Optional[dict] = None,
+        presencePenalty: Optional[dict] = None,
+        countPenalty: Optional[dict] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman: "
     AI_PROMPT = "\n\nAssistant: "


-class AmazonLlamaConfig():
+class AmazonLlamaConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
@@ -182,48 +247,72 @@ class AmazonLlamaConfig():
     - `temperature` (float) temperature for model,
     - `top_p` (float) top p for model
     """
-    max_gen_len: Optional[int]=None
-    temperature: Optional[float]=None
-    topP: Optional[float]=None
-
-    def __init__(self,
-                 maxTokenCount: Optional[int]=None,
-                 temperature: Optional[float]=None,
-                 topP: Optional[int]=None) -> None:
+    max_gen_len: Optional[int] = None
+    temperature: Optional[float] = None
+    topP: Optional[float] = None
+
+    def __init__(
+        self,
+        maxTokenCount: Optional[int] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[int] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


 def init_bedrock_client(
-    region_name = None,
-    aws_access_key_id: Optional[str] = None,
-    aws_secret_access_key: Optional[str] = None,
-    aws_region_name: Optional[str] =None,
-    aws_bedrock_runtime_endpoint: Optional[str]=None,
-):
+    region_name=None,
+    aws_access_key_id: Optional[str] = None,
+    aws_secret_access_key: Optional[str] = None,
+    aws_region_name: Optional[str] = None,
+    aws_bedrock_runtime_endpoint: Optional[str] = None,
+):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
     standard_aws_region_name = get_secret("AWS_REGION", None)

     ## CHECK IS 'os.environ/' passed in
     # Define the list of parameters to check
-    params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
+    params_to_check = [
+        aws_access_key_id,
+        aws_secret_access_key,
+        aws_region_name,
+        aws_bedrock_runtime_endpoint,
+    ]

     # Iterate over parameters and update if needed
     for i, param in enumerate(params_to_check):
-        if param and param.startswith('os.environ/'):
+        if param and param.startswith("os.environ/"):
             params_to_check[i] = get_secret(param)
     # Assign updated values back to parameters
-    aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
+    (
+        aws_access_key_id,
+        aws_secret_access_key,
+        aws_region_name,
+        aws_bedrock_runtime_endpoint,
+    ) = params_to_check
     if region_name:
         pass
     elif aws_region_name:
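The parameter scan above implements litellm's "os.environ/" convention: each credential argument may be a literal value or a pointer to an environment variable, resolved through get_secret. A simplified sketch of that resolution (litellm's real get_secret can also consult configured secret managers, omitted here):

    import os


    def resolve(value):
        # simplified stand-in for litellm's get_secret
        if value and value.startswith("os.environ/"):
            return os.environ.get(value.split("/", 1)[1])
        return value


    os.environ["AWS_REGION_NAME"] = "us-west-2"
    print(resolve("os.environ/AWS_REGION_NAME"))  # us-west-2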
@@ -233,7 +322,10 @@ def init_bedrock_client(
     elif standard_aws_region_name:
         region_name = standard_aws_region_name
     else:
-        raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
+        raise BedrockError(
+            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
+            status_code=401,
+        )

     # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
     env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@@ -242,9 +334,10 @@ def init_bedrock_client(
     elif env_aws_bedrock_runtime_endpoint:
         endpoint_url = env_aws_bedrock_runtime_endpoint
     else:
-        endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com'
+        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

     import boto3

     if aws_access_key_id != None:
         # uses auth params passed to completion
         # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@@ -257,7 +350,7 @@ def init_bedrock_client(
             endpoint_url=endpoint_url,
         )
     else:
         # aws_access_key_id is None, assume user is trying to auth using env variables
         # boto3 automatically reads env variables

         client = boto3.client(
@@ -276,25 +369,23 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
             )
         else:
-            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
+            prompt = prompt_factory(
+                model=model, messages=messages, custom_llm_provider="anthropic"
+            )
     else:
         prompt = ""
         for message in messages:
             if "role" in message:
                 if message["role"] == "user":
-                    prompt += (
-                        f"{message['content']}"
-                    )
+                    prompt += f"{message['content']}"
                 else:
-                    prompt += (
-                        f"{message['content']}"
-                    )
+                    prompt += f"{message['content']}"
             else:
                 prompt += f"{message['content']}"
     return prompt
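For providers without a registered custom prompt, the fallback branch above simply concatenates message contents in order, regardless of role. A quick check of that behavior with hypothetical messages:

    messages = [
        {"role": "user", "content": "Hi. "},
        {"role": "assistant", "content": "Hello!"},
    ]
    prompt = ""
    for message in messages:
        prompt += f"{message['content']}"
    print(prompt)  # Hi. Hello!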
@@ -309,17 +400,18 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""

 # set os.environ['AWS_REGION_NAME'] = <your-region_name>


 def completion(
-        model: str,
-        messages: list,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        logging_obj,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
+    model: str,
+    messages: list,
+    custom_prompt_dict: dict,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
 ):
     exception_mapping_worked = False
     try:
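This completion() is the Bedrock backend behind litellm's public entry point; callers normally reach it through litellm.completion with a "bedrock/" model prefix rather than importing this module directly. A usage sketch (the model id and AWS credentials in the environment are assumptions about the caller's setup):

    import litellm

    response = litellm.completion(
        model="bedrock/anthropic.claude-instant-v1",
        messages=[{"role": "user", "content": "Say hi"}],
        temperature=0.2,
    )
    print(response["choices"][0]["message"]["content"])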
@@ -327,7 +419,9 @@ def completion(
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
         aws_access_key_id = optional_params.pop("aws_access_key_id", None)
         aws_region_name = optional_params.pop("aws_region_name", None)
-        aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
+        aws_bedrock_runtime_endpoint = optional_params.pop(
+            "aws_bedrock_runtime_endpoint", None
+        )

         # use passed in BedrockRuntime.Client if provided, otherwise create a new one
         client = optional_params.pop(
@@ -343,67 +437,71 @@ def completion(

         model = model
         provider = model.split(".")[0]
-        prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
+        prompt = convert_messages_to_prompt(
+            model, messages, provider, custom_prompt_dict
+        )
         inference_params = copy.deepcopy(optional_params)
         stream = inference_params.pop("stream", False)
         if provider == "anthropic":
             ## LOAD CONFIG
-            config = litellm.AmazonAnthropicConfig.get_config()
-            for k, v in config.items():
-                if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+            config = litellm.AmazonAnthropicConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
-            data = json.dumps({
-                "prompt": prompt,
-                **inference_params
-            })
+            data = json.dumps({"prompt": prompt, **inference_params})
         elif provider == "ai21":
             ## LOAD CONFIG
-            config = litellm.AmazonAI21Config.get_config()
-            for k, v in config.items():
-                if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+            config = litellm.AmazonAI21Config.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v

-            data = json.dumps({
-                "prompt": prompt,
-                **inference_params
-            })
+            data = json.dumps({"prompt": prompt, **inference_params})
         elif provider == "cohere":
             ## LOAD CONFIG
-            config = litellm.AmazonCohereConfig.get_config()
-            for k, v in config.items():
-                if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+            config = litellm.AmazonCohereConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             if optional_params.get("stream", False) == True:
-                inference_params["stream"] = True # cohere requires stream = True in inference params
-            data = json.dumps({
-                "prompt": prompt,
-                **inference_params
-            })
+                inference_params[
+                    "stream"
+                ] = True  # cohere requires stream = True in inference params
+            data = json.dumps({"prompt": prompt, **inference_params})
         elif provider == "meta":
             ## LOAD CONFIG
             config = litellm.AmazonLlamaConfig.get_config()
-            for k, v in config.items():
-                if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
-            data = json.dumps({
-                "prompt": prompt,
-                **inference_params
-            })
+            data = json.dumps({"prompt": prompt, **inference_params})
         elif provider == "amazon":  # amazon titan
             ## LOAD CONFIG
-            config = litellm.AmazonTitanConfig.get_config()
-            for k, v in config.items():
-                if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+            config = litellm.AmazonTitanConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v

-            data = json.dumps({
-                "inputText": prompt,
-                "textGenerationConfig": inference_params,
-            })
+            data = json.dumps(
+                {
+                    "inputText": prompt,
+                    "textGenerationConfig": inference_params,
+                }
+            )

         ## COMPLETION CALL
-        accept = 'application/json'
-        contentType = 'application/json'
+        accept = "application/json"
+        contentType = "application/json"
         if stream == True:
             if provider == "ai21":
                 ## LOGGING
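Note that the provider dispatch above produces two distinct request-body shapes for the same call path: most providers take a flat prompt plus inference params, while Amazon Titan nests the params. A sketch with illustrative values:

    import json

    # anthropic / ai21 / cohere / meta: flat body
    data_flat = json.dumps({"prompt": "\n\nHuman: hi\n\nAssistant: ", "temperature": 0.7})

    # amazon titan: prompt under inputText, params under textGenerationConfig
    data_titan = json.dumps(
        {"inputText": "hi", "textGenerationConfig": {"maxTokenCount": 512}}
    )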
@@ -418,17 +516,17 @@ def completion(
                 logging_obj.pre_call(
                     input=prompt,
                     api_key="",
-                    additional_args={"complete_input_dict": data, "request_str": request_str},
+                    additional_args={
+                        "complete_input_dict": data,
+                        "request_str": request_str,
+                    },
                 )

                 response = client.invoke_model(
-                    body=data,
-                    modelId=model,
-                    accept=accept,
-                    contentType=contentType
+                    body=data, modelId=model, accept=accept, contentType=contentType
                 )

-                response = response.get('body').read()
+                response = response.get("body").read()
                 return response
             else:
                 ## LOGGING
@@ -441,20 +539,20 @@ def completion(
                 )
                 """
                 logging_obj.pre_call(
-                    input=prompt,
-                    api_key="",
-                    additional_args={"complete_input_dict": data, "request_str": request_str},
+                    input=prompt,
+                    api_key="",
+                    additional_args={
+                        "complete_input_dict": data,
+                        "request_str": request_str,
+                    },
                 )

                 response = client.invoke_model_with_response_stream(
-                    body=data,
-                    modelId=model,
-                    accept=accept,
-                    contentType=contentType
+                    body=data, modelId=model, accept=accept, contentType=contentType
                 )
-                response = response.get('body')
+                response = response.get("body")
                 return response
         try:
             ## LOGGING
             request_str = f"""
             response = client.invoke_model(
@@ -465,20 +563,20 @@ def completion(
             )
             """
             logging_obj.pre_call(
-                input=prompt,
-                api_key="",
-                additional_args={"complete_input_dict": data, "request_str": request_str},
-            )
-            response = client.invoke_model(
-                body=data,
-                modelId=model,
-                accept=accept,
-                contentType=contentType
+                input=prompt,
+                api_key="",
+                additional_args={
+                    "complete_input_dict": data,
+                    "request_str": request_str,
+                },
+            )
+            response = client.invoke_model(
+                body=data, modelId=model, accept=accept, contentType=contentType
             )
         except Exception as e:
             raise BedrockError(status_code=500, message=str(e))

-        response_body = json.loads(response.get('body').read())
+        response_body = json.loads(response.get("body").read())

         ## LOGGING
         logging_obj.post_call(
@@ -491,16 +589,16 @@ def completion(
         ## RESPONSE OBJECT
         outputText = "default"
         if provider == "ai21":
-            outputText = response_body.get('completions')[0].get('data').get('text')
+            outputText = response_body.get("completions")[0].get("data").get("text")
         elif provider == "anthropic":
-            outputText = response_body['completion']
+            outputText = response_body["completion"]
             model_response["finish_reason"] = response_body["stop_reason"]
         elif provider == "cohere":
             outputText = response_body["generations"][0]["text"]
         elif provider == "meta":
             outputText = response_body["generation"]
         else:  # amazon titan
-            outputText = response_body.get('results')[0].get('outputText')
+            outputText = response_body.get("results")[0].get("outputText")

         response_metadata = response.get("ResponseMetadata", {})
         if response_metadata.get("HTTPStatusCode", 500) >= 400:
@@ -513,12 +611,13 @@ def completion(
             if len(outputText) > 0:
                 model_response["choices"][0]["message"]["content"] = outputText
         except:
-            raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
+            raise BedrockError(
+                message=json.dumps(outputText),
+                status_code=response_metadata.get("HTTPStatusCode", 500),
+            )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )
+        prompt_tokens = len(encoding.encode(prompt))
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"].get("content", ""))
         )
@@ -528,41 +627,47 @@ def completion(
         usage = Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
-            total_tokens = prompt_tokens + completion_tokens
+            total_tokens=prompt_tokens + completion_tokens,
         )
         model_response.usage = usage
         return model_response
     except BedrockError as e:
         exception_mapping_worked = True
         raise e
     except Exception as e:
         if exception_mapping_worked:
             raise e
         else:
             import traceback

             raise BedrockError(status_code=500, message=traceback.format_exc())


 def _embedding_func_single(
-        model: str,
-        input: str,
-        client: Any,
-        optional_params=None,
-        encoding=None,
-        logging_obj=None,
+    model: str,
+    input: str,
+    client: Any,
+    optional_params=None,
+    encoding=None,
+    logging_obj=None,
 ):
     # logic for parsing in - calling - parsing out model embedding calls
     ## FORMAT EMBEDDING INPUT ##
     provider = model.split(".")[0]
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop("user", None) # make sure user is not passed in for bedrock call
+    inference_params.pop(
+        "user", None
+    )  # make sure user is not passed in for bedrock call
     if provider == "amazon":
         input = input.replace(os.linesep, " ")
         data = {"inputText": input, **inference_params}
         # data = json.dumps(data)
     elif provider == "cohere":
-        inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
-        data = {"texts": [input], **inference_params} # type: ignore
+        inference_params["input_type"] = inference_params.get(
+            "input_type", "search_document"
+        )  # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
+        data = {"texts": [input], **inference_params}  # type: ignore
     body = json.dumps(data).encode("utf-8")
     ## LOGGING
     request_str = f"""
     response = client.invoke_model(
@@ -570,12 +675,14 @@ def _embedding_func_single(
         modelId={model},
         accept="*/*",
         contentType="application/json",
-    )""" # type: ignore
+    )"""  # type: ignore
     logging_obj.pre_call(
         input=input,
-        api_key="", # boto3 is used for init.
-        additional_args={"complete_input_dict": {"model": model,
-                                                 "texts": input}, "request_str": request_str},
+        api_key="",  # boto3 is used for init.
+        additional_args={
+            "complete_input_dict": {"model": model, "texts": input},
+            "request_str": request_str,
+        },
     )
     try:
         response = client.invoke_model(
@@ -587,11 +694,11 @@ def _embedding_func_single(
         response_body = json.loads(response.get("body").read())
         ## LOGGING
         logging_obj.post_call(
-                input=input,
-                api_key="",
-                additional_args={"complete_input_dict": data},
-                original_response=json.dumps(response_body),
-            )
+            input=input,
+            api_key="",
+            additional_args={"complete_input_dict": data},
+            original_response=json.dumps(response_body),
+        )
         if provider == "cohere":
             response = response_body.get("embeddings")
             # flatten list
@@ -600,7 +707,10 @@ def _embedding_func_single(
         elif provider == "amazon":
             return response_body.get("embedding")
     except Exception as e:
-        raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
+        raise BedrockError(
+            message=f"Embedding Error with model {model}: {e}", status_code=500
+        )


 def embedding(
     model: str,
@@ -616,7 +726,9 @@ def embedding(
     aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
     aws_access_key_id = optional_params.pop("aws_access_key_id", None)
     aws_region_name = optional_params.pop("aws_region_name", None)
-    aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
+    aws_bedrock_runtime_endpoint = optional_params.pop(
+        "aws_bedrock_runtime_endpoint", None
+    )

     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = init_bedrock_client(
@@ -624,11 +736,19 @@ def embedding(
         aws_secret_access_key=aws_secret_access_key,
         aws_region_name=aws_region_name,
         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
     )

     ## Embedding Call
-    embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
+    embeddings = [
+        _embedding_func_single(
+            model,
+            i,
+            optional_params=optional_params,
+            client=client,
+            logging_obj=logging_obj,
+        )
+        for i in input
+    ]  # [TODO]: make these parallel calls

     ## Populate OpenAI compliant dictionary
     embedding_response = []
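The comprehension above calls invoke_model once per input, sequentially, and the inline [TODO] flags parallelization as future work. One possible shape for that, sketched with a thread pool (this is not what the code does today; error handling is elided):

    from concurrent.futures import ThreadPoolExecutor


    def embed_all(model, inputs, optional_params, client, logging_obj):
        # hypothetical parallel variant of the sequential comprehension above
        with ThreadPoolExecutor(max_workers=8) as pool:
            futures = [
                pool.submit(
                    _embedding_func_single,
                    model,
                    i,
                    optional_params=optional_params,
                    client=client,
                    logging_obj=logging_obj,
                )
                for i in inputs
            ]
            return [f.result() for f in futures]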
@@ -647,13 +767,11 @@ def embedding(

     input_str = "".join(input)

-    input_tokens+=len(encoding.encode(input_str))
+    input_tokens += len(encoding.encode(input_str))

     usage = Usage(
-        prompt_tokens=input_tokens,
-        completion_tokens=0,
-        total_tokens=input_tokens + 0
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
     )
     model_response.usage = usage

     return model_response