Bedrock Embeddings refactor + model support (#5462)
* refactor(bedrock): initial commit to refactor bedrock to a folder (improve code readability + maintainability)
* refactor: more refactor work
* fix: fix imports
* feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats
* fix: fix linting errors
* test: skip test on end of life model
* fix(cohere/embed.py): fix linting error
* fix(cohere/embed.py): fix typing
* fix(cohere/embed.py): fix post-call logging for cohere embedding call
* test(test_embeddings.py): fix error message assertion in test
Parent: 6fb82aaf75
Commit: 37f9705d6e
21 changed files with 1946 additions and 1659 deletions
.pre-commit-config.yaml

@@ -1,12 +1,12 @@
 repos:
   - repo: local
     hooks:
-      - id: mypy
-        name: mypy
-        entry: python3 -m mypy --ignore-missing-imports
-        language: system
-        types: [python]
-        files: ^litellm/
+      # - id: mypy
+      #   name: mypy
+      #   entry: python3 -m mypy --ignore-missing-imports
+      #   language: system
+      #   types: [python]
+      #   files: ^litellm/
       - id: isort
         name: isort
         entry: isort
@@ -118,6 +118,8 @@ in_memory_llm_clients_cache: dict = {}
 safe_memory_mode: bool = False
 ### DEFAULT AZURE API VERSION ###
 AZURE_DEFAULT_API_VERSION = "2024-07-01-preview"  # this is updated to the latest
+### COHERE EMBEDDINGS DEFAULT TYPE ###
+COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
 ### GUARDRAILS ###
 llamaguard_model_name: Optional[str] = None
 openai_moderations_model_name: Optional[str] = None
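With this change, Bedrock/Cohere embedding requests that do not specify an `input_type` fall back to the new module-level default. A minimal sketch of relying on or overriding it (assuming the standard `litellm.embedding` entrypoint and that provider-specific kwargs such as `input_type` are forwarded into the request body, as the cohere transformation later in this diff does):

import litellm

# Change the process-wide default (ships as "search_document").
litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_query"

# Or override per request; the per-request value takes precedence.
response = litellm.embedding(
    model="bedrock/cohere.embed-english-v3",  # illustrative Bedrock model id
    input=["hello from litellm"],
    input_type="search_document",
)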
@@ -880,13 +882,13 @@ from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock_httpx import (
+from .llms.bedrock.chat import (
     AmazonCohereChatConfig,
     AmazonConverseConfig,
     BEDROCK_CONVERSE_MODELS,
     bedrock_tool_name_mappings,
 )
-from .llms.bedrock import (
+from .llms.bedrock.common_utils import (
     AmazonTitanConfig,
     AmazonAI21Config,
     AmazonAnthropicConfig,
(File diff suppressed because it is too large)

litellm/llms/bedrock/chat.py
@@ -1,6 +1,7 @@
-# What is this?
-## Initial implementation of calling bedrock via httpx client (allows for async calls).
-## V1 - covers cohere + anthropic claude-3 support
+"""
+Manages calling Bedrock's `/converse` API + `/invoke` API
+"""
 
 import copy
 import json
 import os
@@ -28,7 +29,7 @@ import requests  # type: ignore
 
 import litellm
 from litellm import verbose_logger
-from litellm.caching import DualCache, InMemoryCache
+from litellm.caching import InMemoryCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
@@ -39,21 +40,16 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.types.llms.bedrock import *
 from litellm.types.llms.openai import (
-    ChatCompletionDeltaChunk,
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import Choices
 from litellm.types.utils import GenericStreamingChunk as GChunk
-from litellm.types.utils import Message
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
 
-from .base import BaseLLM
-from .base_aws_llm import BaseAWSLLM
-from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
-from .prompt_templates.factory import (
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import (
     _bedrock_converse_messages_pt,
     _bedrock_tools_pt,
     cohere_message_pt,
@@ -64,6 +60,7 @@ from .prompt_templates.factory import (
     parse_xml_params,
     prompt_factory,
 )
+from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint
 
 BEDROCK_CONVERSE_MODELS = [
     "anthropic.claude-3-5-sonnet-20240620-v1:0",
@@ -727,22 +724,13 @@ class BedrockLLM(BaseAWSLLM):
         )
 
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = ""
-        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if api_base is not None:
-            endpoint_url = api_base
-        elif aws_bedrock_runtime_endpoint is not None and isinstance(
-            aws_bedrock_runtime_endpoint, str
-        ):
-            endpoint_url = aws_bedrock_runtime_endpoint
-        elif env_aws_bedrock_runtime_endpoint and isinstance(
-            env_aws_bedrock_runtime_endpoint, str
-        ):
-            endpoint_url = env_aws_bedrock_runtime_endpoint
-        else:
-            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
+        endpoint_url = get_runtime_endpoint(
+            api_base=api_base,
+            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+            aws_region_name=aws_region_name,
+        )
 
-        if (stream is not None and stream == True) and provider != "ai21":
+        if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
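Both `BedrockLLM` and `BedrockConverseLLM` (next hunk) now delegate endpoint selection to `get_runtime_endpoint`, defined in full in `common_utils.py` later in this diff. Resolution order: explicit `api_base`, then the `aws_bedrock_runtime_endpoint` parameter, then the `AWS_BEDROCK_RUNTIME_ENDPOINT` secret/env var, then the regional default. A sketch of the expected behavior (assuming `AWS_BEDROCK_RUNTIME_ENDPOINT` is unset in the environment):

from litellm.llms.bedrock.common_utils import get_runtime_endpoint

# Regional default when nothing else is configured.
assert get_runtime_endpoint(
    api_base=None,
    aws_bedrock_runtime_endpoint=None,
    aws_region_name="us-east-1",
) == "https://bedrock-runtime.us-east-1.amazonaws.com"

# An explicit api_base wins over everything else.
assert get_runtime_endpoint(
    api_base="https://my-bedrock-proxy.internal",  # hypothetical proxy URL
    aws_bedrock_runtime_endpoint="https://vpce.example.com",
    aws_region_name="us-east-1",
) == "https://my-bedrock-proxy.internal"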
@@ -1561,21 +1549,11 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
 
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = ""
-        env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if api_base is not None:
-            endpoint_url = api_base
-        elif aws_bedrock_runtime_endpoint is not None and isinstance(
-            aws_bedrock_runtime_endpoint, str
-        ):
-            endpoint_url = aws_bedrock_runtime_endpoint
-        elif env_aws_bedrock_runtime_endpoint and isinstance(
-            env_aws_bedrock_runtime_endpoint, str
-        ):
-            endpoint_url = env_aws_bedrock_runtime_endpoint
-        else:
-            endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
+        endpoint_url = get_runtime_endpoint(
+            api_base=api_base,
+            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+            aws_region_name=aws_region_name,
+        )
 
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
         else:
litellm/llms/bedrock/common_utils.py (new file, 773 lines)
@@ -0,0 +1,773 @@
"""
Common utilities used across bedrock chat/embedding/image generation
"""

import os
import types
from enum import Enum
from typing import List, Optional, Union

import httpx

import litellm
from litellm import get_secret


class BedrockError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class AmazonBedrockGlobalConfig:
    def __init__(self):
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "eu-west-1",
            "eu-west-3",
            "eu-central-1",
        ]


class AmazonTitanConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonAnthropicClaude3Config:
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison

    Supported Params for the Amazon / Anthropic Claude 3 models:

    - `max_tokens` Required (integer) max tokens. Default is 4096
    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
    - `temperature` Optional (float) The amount of randomness injected into the response
    - `top_p` Optional (float) Use nucleus sampling.
    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
    """

    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
    anthropic_version: Optional[str] = "bedrock-2023-05-31"
    system: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "max_tokens",
            "tools",
            "tool_choice",
            "stream",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
        return optional_params


class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(
        self,
    ):
        return ["max_tokens", "temperature", "stop", "top_p", "stream"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "stream" and value == True:
                optional_params["stream"] = value
        return optional_params


class AmazonCohereConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonAI21Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "


class AmazonLlamaConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonMistralConfig:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
    - `top_k` (float) top k for model
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[int] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonStabilityConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Supported Params for the Amazon / Stable Diffusion models:

    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)

    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)

    - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.

    - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
      Engine-specific dimension validation:
      - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
      - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
      - SDXL v1.0: same as SDXL v0.9
      - SD v1.6: must be between 320x320 and 1536x1536

    - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
      Engine-specific dimension validation:
      - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
      - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
      - SDXL v1.0: same as SDXL v0.9
      - SD v1.6: must be between 320x320 and 1536x1536
    """

    cfg_scale: Optional[int] = None
    seed: Optional[float] = None
    steps: Optional[List[str]] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


def add_custom_header(headers):
    """Closure to capture the headers and add them."""

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback


def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)
    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)

    ### SET REGION NAME
    if region_name:
        pass
    elif aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if aws_bedrock_runtime_endpoint:
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint:
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

    import boto3

    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials

        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables

        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    if extra_headers:
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client


def get_runtime_endpoint(
    api_base: Optional[str],
    aws_bedrock_runtime_endpoint: Optional[str],
    aws_region_name: str,
) -> str:
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if api_base is not None:
        endpoint_url = api_base
    elif aws_bedrock_runtime_endpoint is not None and isinstance(
        aws_bedrock_runtime_endpoint, str
    ):
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint and isinstance(
        env_aws_bedrock_runtime_endpoint, str
    ):
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

    return endpoint_url


class ModelResponseIterator:
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response
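Every provider config class above follows the same reflection pattern: `__init__` writes non-None arguments onto the class itself, and `get_config` collects the non-callable, non-dunder, non-None class attributes, so only explicitly-set defaults become request parameters. A small illustration of that pattern (the parameter values are hypothetical):

from litellm.llms.bedrock.common_utils import AmazonMistralConfig

# Setting attributes via __init__ mutates class-level defaults, so every
# subsequent get_config() call sees them until they are changed again.
AmazonMistralConfig(max_tokens=256, temperature=0.2)
print(AmazonMistralConfig.get_config())  # {'max_tokens': 256, 'temperature': 0.2}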
litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.

Why separate file? Make it easy to see how transformation works

Covers:
- G1 request format

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""

import types
from typing import List, Optional

from litellm.types.llms.bedrock import (
    AmazonTitanG1EmbeddingRequest,
    AmazonTitanG1EmbeddingResponse,
    AmazonTitanV2EmbeddingRequest,
    AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage


class AmazonTitanG1Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
    """

    def __init__(
        self,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanG1EmbeddingRequest:
        return AmazonTitanG1EmbeddingRequest(inputText=input)

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0

        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanG1EmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)


class AmazonTitanV2Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
    """

    normalize: Optional[bool] = None
    dimensions: Optional[int] = None

    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanV2EmbeddingRequest:
        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0

        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanV2EmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
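The Titan text transforms are deliberately thin: one string in, one `/invoke` body out, with per-item `inputTextTokenCount` folded into an OpenAI-style `Usage`. A rough round-trip sketch (the response dict is a stand-in for a real Bedrock payload):

from litellm.llms.bedrock.embed.amazon_titan_g1_transformation import (
    AmazonTitanG1Config,
)

cfg = AmazonTitanG1Config()
request = cfg._transform_request(input="hello world", inference_params={})
# -> {"inputText": "hello world"}

fake_response = {"embedding": [0.1, 0.2, 0.3], "inputTextTokenCount": 2}
embedding_response = cfg._transform_response(
    response_list=[fake_response], model="amazon.titan-embed-text-v1"
)
print(embedding_response.usage.prompt_tokens)  # 2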
litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py (new file, 54 lines)

@@ -0,0 +1,54 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.

Why separate file? Make it easy to see how transformation works

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
"""

from typing import List

from litellm.types.llms.bedrock import (
    AmazonTitanMultimodalEmbeddingConfig,
    AmazonTitanMultimodalEmbeddingRequest,
    AmazonTitanMultimodalEmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from litellm.utils import is_base64_encoded


def _transform_request(
    input: str, inference_params: dict
) -> AmazonTitanMultimodalEmbeddingRequest:
    ## check if b64 encoded str or not ##
    is_encoded = is_base64_encoded(input)
    if is_encoded:  # check if string is b64 encoded image or not
        transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputImage=input)
    else:
        transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)

    for k, v in inference_params.items():
        transformed_request[k] = v  # type: ignore

    return transformed_request


def _transform_response(response_list: List[dict], model: str) -> EmbeddingResponse:

    total_prompt_tokens = 0
    transformed_responses: List[Embedding] = []
    for index, response in enumerate(response_list):
        _parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response)  # type: ignore
        transformed_responses.append(
            Embedding(
                embedding=_parsed_response["embedding"], index=index, object="embedding"
            )
        )
        total_prompt_tokens += _parsed_response["inputTextTokenCount"]

    usage = Usage(
        prompt_tokens=total_prompt_tokens,
        completion_tokens=0,
        total_tokens=total_prompt_tokens,
    )
    return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
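The multimodal transform branches on whether the input looks like a base64 payload: images go into `inputImage`, plain text into `inputText`. A sketch of both branches (the base64 string is a stand-in, not a real image; exactly which string shapes `is_base64_encoded` accepts is an assumption here):

from litellm.llms.bedrock.embed.amazon_titan_multimodal_transformation import (
    _transform_request,
)

text_request = _transform_request(input="a photo of a cat", inference_params={})
# -> {"inputText": "a photo of a cat"}

image_request = _transform_request(
    input="data:image/png;base64,iVBORw0KGgo=",  # illustrative payload only
    inference_params={},
)
# -> {"inputImage": "data:image/png;base64,iVBORw0KGgo="}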
litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.

Why separate file? Make it easy to see how transformation works

Covers:
- v2 request format

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""

import types
from typing import List, Optional

from litellm.types.llms.bedrock import (
    AmazonTitanV2EmbeddingRequest,
    AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage


class AmazonTitanV2Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
    """

    normalize: Optional[bool] = None
    dimensions: Optional[int] = None

    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanV2EmbeddingRequest:
        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0

        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanV2EmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
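Unlike G1, the V2 request spreads `inference_params` into the body, so `dimensions` and `normalize` pass straight through to the `/invoke` payload. A sketch:

from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import (
    AmazonTitanV2Config,
)

request = AmazonTitanV2Config()._transform_request(
    input="hello world",
    inference_params={"dimensions": 512, "normalize": True},
)
# -> {"inputText": "hello world", "dimensions": 512, "normalize": True}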
litellm/llms/bedrock/embed/cohere_transformation.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.

Why separate file? Make it easy to see how transformation works
"""

from typing import List

import litellm
from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse
from litellm.types.utils import Embedding, EmbeddingResponse


def _transform_request(
    input: List[str], inference_params: dict
) -> CohereEmbeddingRequest:
    transformed_request = CohereEmbeddingRequest(
        texts=input,
        input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,  # type: ignore
    )

    for k, v in inference_params.items():
        transformed_request[k] = v  # type: ignore

    return transformed_request
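The Cohere transform is list-in, list-out: the whole batch becomes a single `texts` array, with `input_type` defaulting to the module-level setting added earlier in this diff, and any remaining inference params copied through verbatim. A sketch (`truncate` is a documented Cohere embed parameter, used here only as an example of that pass-through):

import litellm
from litellm.llms.bedrock.embed.cohere_transformation import _transform_request

request = _transform_request(
    input=["first doc", "second doc"],
    inference_params={"truncate": "END"},
)
# -> {"texts": ["first doc", "second doc"],
#     "input_type": "search_document", "truncate": "END"}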
litellm/llms/bedrock/embed/embedding.py (new file, 498 lines)
@@ -0,0 +1,498 @@
"""
Handles embedding calls to Bedrock's `/invoke` endpoint
"""

import copy
import json
import os
from copy import deepcopy
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import httpx

import litellm
from litellm import get_secret
from litellm.llms.cohere.embed import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
)
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
from litellm.types.utils import Embedding, EmbeddingResponse, Usage

from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError, get_runtime_endpoint
from .amazon_titan_g1_transformation import AmazonTitanG1Config
from .amazon_titan_multimodal_transformation import (
    _transform_request as amazon_multimodal_transform_request,
)
from .amazon_titan_multimodal_transformation import (
    _transform_response as amazon_multimodal_transform_response,
)
from .amazon_titan_v2_transformation import AmazonTitanV2Config
from .cohere_transformation import _transform_request as cohere_transform_request


class BedrockEmbedding(BaseAWSLLM):
    def _load_credentials(
        self,
        optional_params: dict,
    ) -> Tuple[Any, str]:
        try:
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    async def async_embeddings(self):
        pass

    def _make_sync_call(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        if client is None or not isinstance(client, HTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        try:
            response = client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return response.json()

    def _single_func_embeddings(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        responses: List[dict] = []
        for data in batch_data:
            sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
            headers = {"Content-Type": "application/json"}
            if extra_headers is not None:
                headers = {"Content-Type": "application/json", **extra_headers}
            request = AWSRequest(
                method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
            )
            sigv4.add_auth(request)
            if (
                extra_headers is not None and "Authorization" in extra_headers
            ):  # prevent sigv4 from overwriting the auth header
                request.headers["Authorization"] = extra_headers["Authorization"]
            prepped = request.prepare()

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            response = self._make_sync_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=prepped.headers,
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        returned_response: Optional[EmbeddingResponse] = None

        ## TRANSFORM RESPONSE ##
        if model == "amazon.titan-embed-image-v1":
            returned_response = amazon_multimodal_transform_response(
                response_list=responses, model=model
            )
        elif model == "amazon.titan-embed-text-v1":
            returned_response = AmazonTitanG1Config()._transform_response(
                response_list=responses, model=model
            )
        elif model == "amazon.titan-embed-text-v2:0":
            returned_response = AmazonTitanV2Config()._transform_response(
                response_list=responses, model=model
            )

        if returned_response is None:
            raise Exception(
                "Unable to map model response to known provider format. model={}".format(
                    model
                )
            )

        return returned_response

    def embeddings(
        self,
        model: str,
        input: List[str],
        api_base: Optional[str],
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params=None,
        litellm_params=None,
    ) -> EmbeddingResponse:
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        credentials, aws_region_name = self._load_credentials(optional_params)

        ### TRANSFORMATION ###
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        modelId = (
            optional_params.pop("model_id", None) or model
        )  # default to model if not passed

        data: Optional[CohereEmbeddingRequest] = None
|
||||||
|
batch_data: Optional[List] = None
|
||||||
|
if provider == "cohere":
|
||||||
|
data = cohere_transform_request(
|
||||||
|
input=input, inference_params=inference_params
|
||||||
|
)
|
||||||
|
elif provider == "amazon" and model in [
|
||||||
|
"amazon.titan-embed-image-v1",
|
||||||
|
"amazon.titan-embed-text-v1",
|
||||||
|
"amazon.titan-embed-text-v2:0",
|
||||||
|
]:
|
||||||
|
batch_data = []
|
||||||
|
for i in input:
|
||||||
|
if model == "amazon.titan-embed-image-v1":
|
||||||
|
transformed_request: AmazonEmbeddingRequest = (
|
||||||
|
amazon_multimodal_transform_request(
|
||||||
|
input=i, inference_params=inference_params
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif model == "amazon.titan-embed-text-v1":
|
||||||
|
transformed_request = AmazonTitanG1Config()._transform_request(
|
||||||
|
input=i, inference_params=inference_params
|
||||||
|
)
|
||||||
|
elif model == "amazon.titan-embed-text-v2:0":
|
||||||
|
transformed_request = AmazonTitanV2Config()._transform_request(
|
||||||
|
input=i, inference_params=inference_params
|
||||||
|
)
|
||||||
|
batch_data.append(transformed_request)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url = get_runtime_endpoint(
|
||||||
|
api_base=api_base,
|
||||||
|
aws_bedrock_runtime_endpoint=optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
),
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
|
||||||
|
|
||||||
|
if batch_data is not None:
|
||||||
|
return self._single_func_embeddings(
|
||||||
|
client=(
|
||||||
|
client
|
||||||
|
if client is not None and isinstance(client, HTTPHandler)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
timeout=timeout,
|
||||||
|
batch_data=batch_data,
|
||||||
|
credentials=credentials,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
endpoint_url=endpoint_url,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
model=model,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
elif data is None:
|
||||||
|
raise Exception("Unable to map request to provider")
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
if (
|
||||||
|
extra_headers is not None and "Authorization" in extra_headers
|
||||||
|
): # prevent sigv4 from overwriting the auth header
|
||||||
|
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## ROUTING ##
|
||||||
|
return cohere_embedding(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
data=data, # type: ignore
|
||||||
|
complete_api_base=prepped.url,
|
||||||
|
api_key=None,
|
||||||
|
aembedding=aembedding,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
headers=prepped.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# def _embedding_func_single(
|
||||||
|
# model: str,
|
||||||
|
# input: str,
|
||||||
|
# client: Any,
|
||||||
|
# optional_params=None,
|
||||||
|
# encoding=None,
|
||||||
|
# logging_obj=None,
|
||||||
|
# ):
|
||||||
|
# if isinstance(input, str) is False:
|
||||||
|
# raise BedrockError(
|
||||||
|
# message="Bedrock Embedding API input must be type str | List[str]",
|
||||||
|
# status_code=400,
|
||||||
|
# )
|
||||||
|
# # logic for parsing in - calling - parsing out model embedding calls
|
||||||
|
# ## FORMAT EMBEDDING INPUT ##
|
||||||
|
# provider = model.split(".")[0]
|
||||||
|
# inference_params = copy.deepcopy(optional_params)
|
||||||
|
# inference_params.pop(
|
||||||
|
# "user", None
|
||||||
|
# ) # make sure user is not passed in for bedrock call
|
||||||
|
# modelId = (
|
||||||
|
# optional_params.pop("model_id", None) or model
|
||||||
|
# ) # default to model if not passed
|
||||||
|
# if provider == "amazon":
|
||||||
|
# input = input.replace(os.linesep, " ")
|
||||||
|
# data = {"inputText": input, **inference_params}
|
||||||
|
# # data = json.dumps(data)
|
||||||
|
# elif provider == "cohere":
|
||||||
|
# inference_params["input_type"] = inference_params.get(
|
||||||
|
# "input_type", "search_document"
|
||||||
|
# ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||||
|
# data = {"texts": [input], **inference_params} # type: ignore
|
||||||
|
# body = json.dumps(data).encode("utf-8") # type: ignore
|
||||||
|
# ## LOGGING
|
||||||
|
# request_str = f"""
|
||||||
|
# response = client.invoke_model(
|
||||||
|
# body={body},
|
||||||
|
# modelId={modelId},
|
||||||
|
# accept="*/*",
|
||||||
|
# contentType="application/json",
|
||||||
|
# )""" # type: ignore
|
||||||
|
# logging_obj.pre_call(
|
||||||
|
# input=input,
|
||||||
|
# api_key="", # boto3 is used for init.
|
||||||
|
# additional_args={
|
||||||
|
# "complete_input_dict": {"model": modelId, "texts": input},
|
||||||
|
# "request_str": request_str,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# try:
|
||||||
|
# response = client.invoke_model(
|
||||||
|
# body=body,
|
||||||
|
# modelId=modelId,
|
||||||
|
# accept="*/*",
|
||||||
|
# contentType="application/json",
|
||||||
|
# )
|
||||||
|
# response_body = json.loads(response.get("body").read())
|
||||||
|
# ## LOGGING
|
||||||
|
# logging_obj.post_call(
|
||||||
|
# input=input,
|
||||||
|
# api_key="",
|
||||||
|
# additional_args={"complete_input_dict": data},
|
||||||
|
# original_response=json.dumps(response_body),
|
||||||
|
# )
|
||||||
|
# if provider == "cohere":
|
||||||
|
# response = response_body.get("embeddings")
|
||||||
|
# # flatten list
|
||||||
|
# response = [item for sublist in response for item in sublist]
|
||||||
|
# return response
|
||||||
|
# elif provider == "amazon":
|
||||||
|
# return response_body.get("embedding")
|
||||||
|
# except Exception as e:
|
||||||
|
# raise BedrockError(
|
||||||
|
# message=f"Embedding Error with model {model}: {e}", status_code=500
|
||||||
|
# )
|
||||||
|
|
||||||
|
# def embedding(
|
||||||
|
# model: str,
|
||||||
|
# input: Union[list, str],
|
||||||
|
# model_response: litellm.EmbeddingResponse,
|
||||||
|
# api_key: Optional[str] = None,
|
||||||
|
# logging_obj=None,
|
||||||
|
# optional_params=None,
|
||||||
|
# encoding=None,
|
||||||
|
# ):
|
||||||
|
# ### BOTO3 INIT ###
|
||||||
|
# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
# aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
# aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
# aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
# aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
# aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
# "aws_bedrock_runtime_endpoint", None
|
||||||
|
# )
|
||||||
|
# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
|
# # use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
|
# client = init_bedrock_client(
|
||||||
|
# aws_access_key_id=aws_access_key_id,
|
||||||
|
# aws_secret_access_key=aws_secret_access_key,
|
||||||
|
# aws_region_name=aws_region_name,
|
||||||
|
# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
# aws_web_identity_token=aws_web_identity_token,
|
||||||
|
# aws_role_name=aws_role_name,
|
||||||
|
# aws_session_name=aws_session_name,
|
||||||
|
# )
|
||||||
|
# if isinstance(input, str):
|
||||||
|
# ## Embedding Call
|
||||||
|
# embeddings = [
|
||||||
|
# _embedding_func_single(
|
||||||
|
# model,
|
||||||
|
# input,
|
||||||
|
# optional_params=optional_params,
|
||||||
|
# client=client,
|
||||||
|
# logging_obj=logging_obj,
|
||||||
|
# )
|
||||||
|
# ]
|
||||||
|
# elif isinstance(input, list):
|
||||||
|
# ## Embedding Call - assuming this is a List[str]
|
||||||
|
# embeddings = [
|
||||||
|
# _embedding_func_single(
|
||||||
|
# model,
|
||||||
|
# i,
|
||||||
|
# optional_params=optional_params,
|
||||||
|
# client=client,
|
||||||
|
# logging_obj=logging_obj,
|
||||||
|
# )
|
||||||
|
# for i in input
|
||||||
|
# ] # [TODO]: make these parallel calls
|
||||||
|
# else:
|
||||||
|
# # enters this branch if input = int, ex. input=2
|
||||||
|
# raise BedrockError(
|
||||||
|
# message="Bedrock Embedding API input must be type str | List[str]",
|
||||||
|
# status_code=400,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# ## Populate OpenAI compliant dictionary
|
||||||
|
# embedding_response = []
|
||||||
|
# for idx, embedding in enumerate(embeddings):
|
||||||
|
# embedding_response.append(
|
||||||
|
# {
|
||||||
|
# "object": "embedding",
|
||||||
|
# "index": idx,
|
||||||
|
# "embedding": embedding,
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
# model_response.object = "list"
|
||||||
|
# model_response.data = embedding_response
|
||||||
|
# model_response.model = model
|
||||||
|
# input_tokens = 0
|
||||||
|
|
||||||
|
# input_str = "".join(input)
|
||||||
|
|
||||||
|
# input_tokens += len(encoding.encode(input_str))
|
||||||
|
|
||||||
|
# usage = Usage(
|
||||||
|
# prompt_tokens=input_tokens,
|
||||||
|
# completion_tokens=0,
|
||||||
|
# total_tokens=input_tokens + 0,
|
||||||
|
# )
|
||||||
|
# model_response.usage = usage
|
||||||
|
|
||||||
|
# return model_response
|
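For reference, this handler is what ultimately serves a plain `litellm.embedding` call against a Bedrock model. A minimal usage sketch, assuming AWS credentials are resolvable from the environment; the `bedrock/` model-id prefix follows the convention used in the tests later in this commit:

import litellm

# Sketch: embed two strings with a Bedrock Titan model.
# The handler batches one /invoke request per input string.
response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v2:0",
    input=["hello from litellm", "bedrock embeddings"],
    aws_region_name="us-west-2",
)
print(response["data"][0]["embedding"][:5])  # first few floats of the first vector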
127  litellm/llms/bedrock/image_generation.py  Normal file
@@ -0,0 +1,127 @@
"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""

import copy
import json
import os
from typing import List

from openai.types.image import Image

import litellm
from litellm.types.utils import ImageResponse

from .common_utils import BedrockError, init_bedrock_client


def image_generation(
    model: str,
    prompt: str,
    model_response: ImageResponse,
    optional_params: dict,
    timeout=None,
    logging_obj=None,
    aimg_generation=False,
):
    """
    Bedrock Image Gen endpoint support
    """
    ### BOTO3 INIT ###
    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)
    aws_role_name = optional_params.pop("aws_role_name", None)
    aws_session_name = optional_params.pop("aws_session_name", None)
    aws_bedrock_runtime_endpoint = optional_params.pop(
        "aws_bedrock_runtime_endpoint", None
    )
    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)

    # use passed in BedrockRuntime.Client if provided, otherwise create a new one
    client = init_bedrock_client(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_region_name=aws_region_name,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_web_identity_token=aws_web_identity_token,
        aws_role_name=aws_role_name,
        aws_session_name=aws_session_name,
        timeout=timeout,
    )

    ### FORMAT IMAGE GENERATION INPUT ###
    modelId = model
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call
    data = {}
    if provider == "stability":
        prompt = prompt.replace(os.linesep, " ")
        ## LOAD CONFIG
        config = litellm.AmazonStabilityConfig.get_config()
        for k, v in config.items():
            if (
                k not in inference_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
        data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
    else:
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model}, passed in"
        )

    body = json.dumps(data).encode("utf-8")
    ## LOGGING
    request_str = f"""
    response = client.invoke_model(
        body={body}, # type: ignore
        modelId={modelId},
        accept="application/json",
        contentType="application/json",
    )"""  # type: ignore
    logging_obj.pre_call(
        input=prompt,
        api_key="",  # boto3 is used for init.
        additional_args={
            "complete_input_dict": {"model": modelId, "texts": prompt},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            additional_args={"complete_input_dict": data},
            original_response=json.dumps(response_body),
        )
    except Exception as e:
        raise BedrockError(
            message=f"Embedding Error with model {model}: {e}", status_code=500
        )

    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    if response_body is None:
        raise Exception("Error in response object format")

    if model_response is None:
        model_response = ImageResponse()

    image_list: List[Image] = []
    for artifact in response_body["artifacts"]:
        _image = Image(b64_json=artifact["base64"])
        image_list.append(_image)

    model_response.data = image_list
    return model_response
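A minimal usage sketch of the new module. The model id `stability.stable-diffusion-xl-v1` is an assumption for illustration (only the `stability` provider is supported by this function), and AWS credentials are assumed to be resolvable from the environment:

import litellm

# Sketch: generate one image via Bedrock's Stability models.
image_response = litellm.image_generation(
    prompt="a watercolor painting of a lighthouse at dawn",
    model="bedrock/stability.stable-diffusion-xl-v1",  # assumed/example model id
)
print(image_response.data[0].b64_json[:40])  # base64-encoded image payload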
@@ -76,7 +76,7 @@ async def async_embedding(
     data: dict,
     input: list,
     model_response: litellm.utils.EmbeddingResponse,
-    timeout: Union[float, httpx.Timeout],
+    timeout: Optional[Union[float, httpx.Timeout]],
     logging_obj: LiteLLMLoggingObj,
     optional_params: dict,
     api_base: str,
@@ -98,16 +98,35 @@ async def async_embedding(
     )
     ## COMPLETION CALL
     if client is None:
-        client = AsyncHTTPHandler(concurrent_limit=1)
+        client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)

-    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+    try:
+        response = await client.post(api_base, headers=headers, data=json.dumps(data))
+    except httpx.HTTPStatusError as e:
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=e.response.text,
+        )
+        raise e
+    except Exception as e:
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=str(e),
+        )
+        raise e
+
     ## LOGGING
     logging_obj.post_call(
         input=input,
         api_key=api_key,
         additional_args={"complete_input_dict": data},
-        original_response=response,
+        original_response=response.text,
     )

     embeddings = response.json()["embeddings"]
@@ -130,27 +149,22 @@ def embedding(
     optional_params: dict,
     headers: dict,
     encoding: Any,
+    data: Optional[dict] = None,
+    complete_api_base: Optional[str] = None,
     api_key: Optional[str] = None,
     aembedding: Optional[bool] = None,
-    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
     client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key, headers=headers)
-    embed_url = "https://api.cohere.ai/v1/embed"
+    embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
     model = model
-    data = {"model": model, "texts": input, **optional_params}
+    data = data or {"model": model, "texts": input, **optional_params}

     if "3" in model and "input_type" not in data:
         # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
         data["input_type"] = "search_document"

-    ## LOGGING
-    logging_obj.pre_call(
-        input=input,
-        api_key=api_key,
-        additional_args={"complete_input_dict": data},
-    )
-
     ## ROUTING
     if aembedding is True:
         return async_embedding(
@@ -166,9 +180,18 @@ def embedding(
             headers=headers,
             encoding=encoding,
         )
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+
     ## COMPLETION CALL
     if client is None or not isinstance(client, HTTPHandler):
         client = HTTPHandler(concurrent_limit=1)

     response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
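The two new parameters above (`data`, `complete_api_base`) are what let the Bedrock handler reuse this Cohere embed function: a pre-built request body and a pre-signed Bedrock URL are passed straight through, otherwise the public Cohere endpoint and a fresh body are used. A small standalone sketch of that fallback logic (the helper name is hypothetical; it just mirrors the changed lines):

from typing import Optional

def resolve_embed_target(
    data: Optional[dict],
    complete_api_base: Optional[str],
    model: str,
    input: list,
    optional_params: dict,
):
    # Caller-supplied values win; otherwise fall back to the public
    # Cohere endpoint and build the request body here.
    embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
    body = data or {"model": model, "texts": input, **optional_params}
    if "3" in model and "input_type" not in body:
        body["input_type"] = "search_document"  # cohere v3 default
    return embed_url, body

# Direct Cohere-style call: URL + body are derived inside.
url, body = resolve_embed_target(None, None, "embed-english-v3.0", ["hi"], {})
assert url == "https://api.cohere.ai/v1/embed"
assert body["input_type"] == "search_document"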
@@ -78,7 +78,6 @@ from .llms import (
     ai21,
     aleph_alpha,
     baseten,
-    bedrock,
     clarifai,
     cloudflare,
     maritalk,
@@ -96,7 +95,9 @@ from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.anthropic.completion import AnthropicTextCompletion
 from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure_text import AzureTextCompletion
-from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
+from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
+from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
+from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.cohere import chat as cohere_chat
 from .llms.cohere import completion as cohere_completion  # type: ignore
 from .llms.cohere import embed as cohere_embed
@@ -176,6 +177,7 @@ codestral_text_completions = CodestralTextCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
+bedrock_embedding = BedrockEmbedding()
 vertex_chat_completion = VertexLLM()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()
 google_batch_embeddings = GoogleBatchEmbeddings()
@@ -3151,6 +3153,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "watsonx"
         or custom_llm_provider == "cohere"
         or custom_llm_provider == "huggingface"
+        or custom_llm_provider == "bedrock"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3519,13 +3522,24 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "bedrock":
-        response = bedrock.embedding(
+        if isinstance(input, str):
+            transformed_input = [input]
+        else:
+            transformed_input = input
+        response = bedrock_embedding.embeddings(
             model=model,
-            input=input,
+            input=transformed_input,
             encoding=encoding,
             logging_obj=logging,
            optional_params=optional_params,
             model_response=EmbeddingResponse(),
+            client=client,
+            timeout=timeout,
+            aembedding=aembedding,
+            litellm_params=litellm_params,
+            api_base=api_base,
+            print_verbose=print_verbose,
+            extra_headers=extra_headers,
         )
     elif custom_llm_provider == "triton":
         if api_base is None:
@@ -4493,7 +4507,7 @@ def image_generation(
     elif custom_llm_provider == "bedrock":
         if model is None:
             raise Exception("Model needs to be set for bedrock")
-        model_response = bedrock.image_generation(
+        model_response = bedrock_image_generation.image_generation(
            model=model,
             prompt=prompt,
             timeout=timeout,
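With `bedrock` added to the async allow-list, `litellm.aembedding` now serves Bedrock embedding calls directly; note the router wraps a plain string input into a list before dispatch. A minimal sketch under the same environment assumptions as the sync example earlier:

import asyncio
import litellm

async def main():
    # Sketch: async Bedrock embedding via the routing shown above.
    response = await litellm.aembedding(
        model="bedrock/amazon.titan-embed-text-v1",
        input="good morning from litellm",  # str is wrapped into a list by the router
        aws_region_name="us-west-2",
    )
    print(len(response["data"][0]["embedding"]))

asyncio.run(main())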
@@ -178,7 +178,7 @@ async def bedrock_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    from litellm.llms.bedrock_httpx import BedrockConverseLLM
+    from litellm.llms.bedrock.chat import BedrockConverseLLM

     credentials: Credentials = BedrockConverseLLM().get_credentials()
     sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -25,7 +25,7 @@ from litellm import (
     completion_cost,
     embedding,
 )
-from litellm.llms.bedrock_httpx import BedrockLLM, ToolBlock
+from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
@@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
 # test_cohere_embedding3()


-def test_bedrock_embedding_titan():
+@pytest.mark.parametrize(
+    "model",
+    [
+        "bedrock/amazon.titan-embed-text-v1",
+        "bedrock/amazon.titan-embed-image-v1",
+        "bedrock/amazon.titan-embed-text-v2:0",
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.asyncio
+async def test_bedrock_embedding_titan(model, sync_mode):
     try:
         # this tests if we support str input for bedrock embedding
         litellm.set_verbose = True
@@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():

         current_time = str(time.time())
         # DO NOT MAKE THE INPUT A LIST in this test
-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-            aws_region_name="us-west-2",
-        )
-        print(f"response:", response)
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        print("response:", response)
         assert isinstance(
             response["data"][0]["embedding"], list
         ), "Expected response to be a list"
-        print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
+        print("type of first embedding:", type(response["data"][0]["embedding"][0]))
         assert all(
             isinstance(x, float) for x in response["data"][0]["embedding"]
         ), "Expected response to be a list of floats"
@@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():

         start_time = time.time()

-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-        )
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
         print(response)

         end_time = time.time()
+        print(response._hidden_params)
         print(f"Embedding 2 response time: {end_time - start_time} seconds")

         assert end_time - start_time < 0.1
@@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():

     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
     ):
         litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])

     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
     ):
         litellm.embedding(
             model="amazon.titan-embed-text-v1",
@@ -1,21 +1,25 @@
-import sys, os, uuid
+import os
+import sys
 import time
 import traceback
+import uuid
+
 from dotenv import load_dotenv
+
 load_dotenv()
 import os
-from uuid import uuid4
 import tempfile
+from uuid import uuid4
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest

 from litellm import get_secret
-from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
+from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
+from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager


 @pytest.mark.skip(reason="AWS Suspended Account")
@@ -63,9 +67,7 @@ def test_oidc_github():
     reason="Cannot run without being in CircleCI Runner",
 )
 def test_oidc_circleci():
-    secret_val = get_secret(
-        "oidc/circleci/"
-    )
+    secret_val = get_secret("oidc/circleci/")

     print(f"secret_val: {redact_oidc_signature(secret_val)}")
@@ -103,9 +105,7 @@ def test_oidc_circle_v1_with_amazon():
     # The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token

     # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
-    aws_role_name = (
-        "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
-    )
+    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
     aws_web_identity_token = "oidc/circleci/"

     bllm = BedrockLLM()
@@ -116,6 +116,7 @@ def test_oidc_circle_v1_with_amazon():
         aws_session_name="assume-v1-session",
     )

+
 @pytest.mark.skipif(
     os.environ.get("CIRCLE_OIDC_TOKEN") is None,
     reason="Cannot run without being in CircleCI Runner",
@@ -124,9 +125,7 @@ def test_oidc_circle_v1_with_amazon_fips():
     # The purpose of this test is to validate that we can assume a role in a FIPS region

     # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
-    aws_role_name = (
-        "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
-    )
+    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
     aws_web_identity_token = "oidc/circleci/"

     bllm = BedrockConverseLLM()
@@ -143,9 +142,7 @@ def test_oidc_env_variable():
     # Create a unique environment variable name
     env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
     os.environ[env_var_name] = "secret-" + uuid4().hex
-    secret_val = get_secret(
-        f"oidc/env/{env_var_name}"
-    )
+    secret_val = get_secret(f"oidc/env/{env_var_name}")

     print(f"secret_val: {redact_oidc_signature(secret_val)}")
@@ -157,15 +154,13 @@ def test_oidc_env_variable():

 def test_oidc_file():
     # Create a temporary file
-    with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
+    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
         secret_value = "secret-" + uuid4().hex
         temp_file.write(secret_value)
         temp_file.flush()
         temp_file_path = temp_file.name

-        secret_val = get_secret(
-            f"oidc/file/{temp_file_path}"
-        )
+        secret_val = get_secret(f"oidc/file/{temp_file_path}")

         print(f"secret_val: {redact_oidc_signature(secret_val)}")
@@ -174,7 +169,7 @@ def test_oidc_file():

 def test_oidc_env_path():
     # Create a temporary file
-    with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
+    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
         secret_value = "secret-" + uuid4().hex
         temp_file.write(secret_value)
         temp_file.flush()
@@ -187,9 +182,7 @@ def test_oidc_env_path():
     os.environ[env_var_name] = temp_file_path

     # Test getting the secret using the environment variable
-    secret_val = get_secret(
-        f"oidc/env_path/{env_var_name}"
-    )
+    secret_val = get_secret(f"oidc/env_path/{env_var_name}")

     print(f"secret_val: {redact_oidc_signature(secret_val)}")
@@ -208,3 +208,62 @@ class ServerSentEvent:
     @override
     def __repr__(self) -> str:
         return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
+
+
+class CohereEmbeddingRequest(TypedDict, total=False):
+    texts: Required[List[str]]
+    input_type: Required[
+        Literal["search_document", "search_query", "classification", "clustering"]
+    ]
+    truncate: Literal["NONE", "START", "END"]
+    embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]
+
+
+class CohereEmbeddingResponse(TypedDict):
+    embeddings: List[List[float]]
+    id: str
+    response_type: Literal["embedding_floats"]
+    texts: List[str]
+
+
+class AmazonTitanV2EmbeddingRequest(TypedDict):
+    inputText: str
+    dimensions: int
+    normalize: bool
+
+
+class AmazonTitanV2EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanG1EmbeddingRequest(TypedDict):
+    inputText: str
+
+
+class AmazonTitanG1EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanMultimodalEmbeddingConfig(TypedDict):
+    outputEmbeddingLength: Literal[256, 384, 1024]
+
+
+class AmazonTitanMultimodalEmbeddingRequest(TypedDict, total=False):
+    inputText: str
+    inputImage: str
+    embeddingConfig: AmazonTitanMultimodalEmbeddingConfig
+
+
+class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+    message: str  # Specifies any errors that occur during generation.
+
+
+AmazonEmbeddingRequest = Union[
+    AmazonTitanMultimodalEmbeddingRequest,
+    AmazonTitanV2EmbeddingRequest,
+    AmazonTitanG1EmbeddingRequest,
+]
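These TypedDicts only describe payload shapes; constructing one is plain dict syntax. A quick sketch of valid request bodies under the definitions just added (the image string is a placeholder, not real data):

# Sketch: request body for amazon.titan-embed-image-v1,
# typed with the TypedDicts declared above.
request: AmazonTitanMultimodalEmbeddingRequest = {
    "inputText": "a red bicycle leaning on a wall",
    "inputImage": "<base64-encoded image bytes>",  # placeholder value
    "embeddingConfig": {"outputEmbeddingLength": 256},
}

# Sketch: a Cohere embed request; texts + input_type are required keys.
cohere_request: CohereEmbeddingRequest = {
    "texts": ["hello world"],
    "input_type": "search_document",
    "truncate": "END",
}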
@@ -699,7 +699,7 @@ class ModelResponse(OpenAIObject):
 class Embedding(OpenAIObject):
     embedding: Union[list, str] = []
     index: int
-    object: str
+    object: Literal["embedding"]

     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -721,7 +721,7 @@ class EmbeddingResponse(OpenAIObject):
     data: Optional[List] = None
     """The actual embedding value"""

-    object: str
+    object: Literal["list"]
     """The object type, which is always "embedding" """

     usage: Optional[Usage] = None
@@ -732,11 +732,10 @@ class EmbeddingResponse(OpenAIObject):

     def __init__(
         self,
-        model=None,
-        usage=None,
-        stream=False,
+        model: Optional[str] = None,
+        usage: Optional[Usage] = None,
         response_ms=None,
-        data=None,
+        data: Optional[List] = None,
         hidden_params=None,
         _response_headers=None,
         **params,
@@ -760,7 +759,7 @@ class EmbeddingResponse(OpenAIObject):
         self._response_headers = _response_headers

         model = model
-        super().__init__(model=model, object=object, data=data, usage=usage)
+        super().__init__(model=model, object=object, data=data, usage=usage)  # type: ignore

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
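With the tightened annotations, a well-formed response object looks like the sketch below. The import path matches where these classes live in this commit (`litellm/utils.py`), and the `Usage` field pattern follows the usage construction seen elsewhere in this diff; that `object` defaults to "list" is an assumption implied by the Literal type:

from litellm.utils import EmbeddingResponse, Usage

# Sketch: assembling an OpenAI-shaped embedding response by hand.
response = EmbeddingResponse(
    model="amazon.titan-embed-text-v1",
    data=[{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}],
    usage=Usage(prompt_tokens=5, completion_tokens=0, total_tokens=5),
)
assert response.object == "list"  # assumed default per the Literal annotation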
@@ -854,6 +854,7 @@ def client(original_function):
                     )
                     cached_result = litellm.cache.get_cache(*args, **kwargs)
                     if cached_result is not None:
+                        print_verbose("Cache Hit!")
                         if "detail" in cached_result:
                             # implies an error occurred
                             pass
@@ -935,7 +936,10 @@ def client(original_function):
                             args=(cached_result, start_time, end_time, cache_hit),
                         ).start()
                         return cached_result
-
+                    else:
+                        print_verbose(
+                            "Cache Miss! on key - {}".format(preset_cache_key)
+                        )
                 # CHECK MAX TOKENS
                 if (
                     kwargs.get("max_tokens", None) is not None
@@ -1005,7 +1009,7 @@ def client(original_function):
                     litellm.cache is not None
                     and str(original_function.__name__)
                     in litellm.cache.supported_call_types
-                ) and (kwargs.get("cache", {}).get("no-store", False) != True):
+                ) and (kwargs.get("cache", {}).get("no-store", False) is not True):
                     litellm.cache.add_cache(result, *args, **kwargs)

                 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -1404,10 +1408,10 @@ def client(original_function):
             # MODEL CALL
             result = await original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
-            if "stream" in kwargs and kwargs["stream"] == True:
+            if "stream" in kwargs and kwargs["stream"] is True:
                 if (
                     "complete_response" in kwargs
-                    and kwargs["complete_response"] == True
+                    and kwargs["complete_response"] is True
                 ):
                     chunks = []
                     for idx, chunk in enumerate(result):
@@ -11734,3 +11738,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
         return True

     return False
+
+
+def is_base64_encoded(s: str) -> bool:
+    try:
+        # Try to decode the string
+        decoded_bytes = base64.b64decode(s, validate=True)
+        # Check if the original string can be re-encoded to the same string
+        return base64.b64encode(decoded_bytes).decode("utf-8") == s
+    except Exception:
+        return False
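A quick round-trip check of the new helper (standalone sketch; `validate=True` rejects any non-alphabet characters before the re-encode comparison runs):

import base64

# Sketch: the helper accepts canonical base64 and rejects near-misses.
encoded = base64.b64encode(b"hello bedrock").decode("utf-8")
assert is_base64_encoded(encoded) is True
assert is_base64_encoded("hello bedrock") is False  # space is not in the base64 alphabet
assert is_base64_encoded(encoded + "\n") is False   # trailing newline fails validate=True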