forked from phoenix/litellm-mirror
Bedrock Embeddings refactor + model support (#5462)
* refactor(bedrock): initial commit to refactor bedrock to a folder Improve code readability + maintainability * refactor: more refactor work * fix: fix imports * feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats * fix: fix linting errors * test: skip test on end of life model * fix(cohere/embed.py): fix linting error * fix(cohere/embed.py): fix typing * fix(cohere/embed.py): fix post-call logging for cohere embedding call * test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
parent
6fb82aaf75
commit
37f9705d6e
21 changed files with 1946 additions and 1659 deletions
|
@ -1,12 +1,12 @@
|
|||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
name: mypy
|
||||
entry: python3 -m mypy --ignore-missing-imports
|
||||
language: system
|
||||
types: [python]
|
||||
files: ^litellm/
|
||||
# - id: mypy
|
||||
# name: mypy
|
||||
# entry: python3 -m mypy --ignore-missing-imports
|
||||
# language: system
|
||||
# types: [python]
|
||||
# files: ^litellm/
|
||||
- id: isort
|
||||
name: isort
|
||||
entry: isort
|
||||
|
|
|
@ -118,6 +118,8 @@ in_memory_llm_clients_cache: dict = {}
|
|||
safe_memory_mode: bool = False
|
||||
### DEFAULT AZURE API VERSION ###
|
||||
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
|
||||
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
||||
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
|
||||
### GUARDRAILS ###
|
||||
llamaguard_model_name: Optional[str] = None
|
||||
openai_moderations_model_name: Optional[str] = None
|
||||
|
@ -880,13 +882,13 @@ from .llms.sagemaker.sagemaker import SagemakerConfig
|
|||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock_httpx import (
|
||||
from .llms.bedrock.chat import (
|
||||
AmazonCohereChatConfig,
|
||||
AmazonConverseConfig,
|
||||
BEDROCK_CONVERSE_MODELS,
|
||||
bedrock_tool_name_mappings,
|
||||
)
|
||||
from .llms.bedrock import (
|
||||
from .llms.bedrock.common_utils import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
AmazonAnthropicConfig,
|
||||
|
|
|
@ -608,17 +608,17 @@ class Logging:
|
|||
self.model_call_details["litellm_params"]["metadata"][
|
||||
"hidden_params"
|
||||
] = result._hidden_params
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
)
|
||||
)
|
||||
)
|
||||
else: # streaming chunks + image gen.
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,7 @@
|
|||
# What is this?
|
||||
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||
## V1 - covers cohere + anthropic claude-3 support
|
||||
"""
|
||||
Manages calling Bedrock's `/converse` API + `/invoke` API
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
@ -28,7 +29,7 @@ import requests # type: ignore
|
|||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import DualCache, InMemoryCache
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
|
@ -39,21 +40,16 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.types.llms.bedrock import *
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
)
|
||||
from litellm.types.utils import Choices
|
||||
from litellm.types.utils import GenericStreamingChunk as GChunk
|
||||
from litellm.types.utils import Message
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
|
||||
|
||||
from .base import BaseLLM
|
||||
from .base_aws_llm import BaseAWSLLM
|
||||
from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
|
||||
from .prompt_templates.factory import (
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..prompt_templates.factory import (
|
||||
_bedrock_converse_messages_pt,
|
||||
_bedrock_tools_pt,
|
||||
cohere_message_pt,
|
||||
|
@ -64,6 +60,7 @@ from .prompt_templates.factory import (
|
|||
parse_xml_params,
|
||||
prompt_factory,
|
||||
)
|
||||
from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint
|
||||
|
||||
BEDROCK_CONVERSE_MODELS = [
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
|
@ -727,22 +724,13 @@ class BedrockLLM(BaseAWSLLM):
|
|||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
endpoint_url = get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
||||
if (stream is not None and stream == True) and provider != "ai21":
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
|
||||
|
@ -1561,21 +1549,11 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
endpoint_url = get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||
else:
|
773
litellm/llms/bedrock/common_utils.py
Normal file
773
litellm/llms/bedrock/common_utils.py
Normal file
|
@ -0,0 +1,773 @@
|
|||
"""
|
||||
Common utilities used across bedrock chat/embedding/image generation
|
||||
"""
|
||||
|
||||
import os
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AmazonBedrockGlobalConfig:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
"""
|
||||
Mapping of common auth params across bedrock/vertex/azure/watsonx
|
||||
"""
|
||||
return {"region_name": "aws_region_name"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
mapped_params = self.get_mapped_special_auth_params()
|
||||
for param, value in non_default_params.items():
|
||||
if param in mapped_params:
|
||||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
"""
|
||||
return [
|
||||
"eu-west-1",
|
||||
"eu-west-3",
|
||||
"eu-central-1",
|
||||
]
|
||||
|
||||
|
||||
class AmazonTitanConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
||||
|
||||
Supported Params for the Amazon Titan models:
|
||||
|
||||
- `maxTokenCount` (integer) max tokens,
|
||||
- `stopSequences` (string[]) list of stop sequence strings
|
||||
- `temperature` (float) temperature for model,
|
||||
- `topP` (int) top p for model
|
||||
"""
|
||||
|
||||
maxTokenCount: Optional[int] = None
|
||||
stopSequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonAnthropicClaude3Config:
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude 3 models:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens. Default is 4096
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
|
||||
anthropic_version: Optional[str] = "bedrock-2023-05-31"
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class AmazonAnthropicConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
|
||||
Supported Params for the Amazon / Anthropic models:
|
||||
|
||||
- `max_tokens_to_sample` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `top_k` (integer) top k,
|
||||
- `top_p` (integer) top p,
|
||||
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
|
||||
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
"""
|
||||
|
||||
max_tokens_to_sample: Optional[int] = litellm.max_tokens
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
anthropic_version: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(
|
||||
self,
|
||||
):
|
||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens_to_sample"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "stream" and value == True:
|
||||
optional_params["stream"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class AmazonCohereConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
Supported Params for the Amazon / Cohere models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `return_likelihood` (string) n/a
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonAI21Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
||||
|
||||
Supported Params for the Amazon / AI21 models:
|
||||
|
||||
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
|
||||
|
||||
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
|
||||
|
||||
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
|
||||
|
||||
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
|
||||
|
||||
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
|
||||
|
||||
- `presencePenalty` (object): Placeholder for presence penalty object.
|
||||
|
||||
- `countPenalty` (object): Placeholder for count penalty object.
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
stopSequences: Optional[list] = None
|
||||
frequencePenalty: Optional[dict] = None
|
||||
presencePenalty: Optional[dict] = None
|
||||
countPenalty: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[float] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
frequencePenalty: Optional[dict] = None,
|
||||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
|
||||
class AmazonLlamaConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
||||
|
||||
Supported Params for the Amazon / Meta Llama models:
|
||||
|
||||
- `max_gen_len` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
"""
|
||||
|
||||
max_gen_len: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonMistralConfig:
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
Supported Params for the Amazon / Mistral models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
- `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
|
||||
- `top_k` (float) top k for model
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
|
||||
def callback(request, **kwargs):
|
||||
"""Actual callback function that Boto3 will call."""
|
||||
for header_name, header_value in headers.items():
|
||||
request.headers.add_header(header_name, header_value)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
params_to_check[i] = get_secret(param)
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||
|
||||
### SET REGION NAME
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
region_name = aws_region_name
|
||||
elif litellm_aws_region_name:
|
||||
region_name = litellm_aws_region_name
|
||||
elif standard_aws_region_name:
|
||||
region_name = standard_aws_region_name
|
||||
else:
|
||||
raise BedrockError(
|
||||
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint:
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint:
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
|
||||
|
||||
import boto3
|
||||
|
||||
if isinstance(timeout, float):
|
||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||
elif isinstance(timeout, httpx.Timeout):
|
||||
config = boto3.session.Config(
|
||||
connect_timeout=timeout.connect, read_timeout=timeout.read
|
||||
)
|
||||
else:
|
||||
config = boto3.session.Config()
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_profile_name is not None:
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
|
||||
client = boto3.Session(profile_name=aws_profile_name).client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
if extra_headers:
|
||||
client.meta.events.register(
|
||||
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def get_runtime_endpoint(
|
||||
api_base: Optional[str],
|
||||
aws_bedrock_runtime_endpoint: Optional[str],
|
||||
aws_region_name: str,
|
||||
) -> str:
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
return endpoint_url
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
149
litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py
Normal file
149
litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Convers
|
||||
- G1 request format
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanG1EmbeddingRequest,
|
||||
AmazonTitanG1EmbeddingResponse,
|
||||
AmazonTitanV2EmbeddingRequest,
|
||||
AmazonTitanV2EmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class AmazonTitanG1Config:
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def _transform_request(
|
||||
self, input: str, inference_params: dict
|
||||
) -> AmazonTitanG1EmbeddingRequest:
|
||||
return AmazonTitanG1EmbeddingRequest(inputText=input)
|
||||
|
||||
def _transform_response(
|
||||
self, response_list: List[dict], model: str
|
||||
) -> EmbeddingResponse:
|
||||
total_prompt_tokens = 0
|
||||
|
||||
transformed_responses: List[Embedding] = []
|
||||
for index, response in enumerate(response_list):
|
||||
_parsed_response = AmazonTitanG1EmbeddingResponse(**response) # type: ignore
|
||||
transformed_responses.append(
|
||||
Embedding(
|
||||
embedding=_parsed_response["embedding"],
|
||||
index=index,
|
||||
object="embedding",
|
||||
)
|
||||
)
|
||||
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=total_prompt_tokens,
|
||||
)
|
||||
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
|
||||
|
||||
|
||||
class AmazonTitanV2Config:
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
|
||||
normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
|
||||
dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
|
||||
"""
|
||||
|
||||
normalize: Optional[bool] = None
|
||||
dimensions: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def _transform_request(
|
||||
self, input: str, inference_params: dict
|
||||
) -> AmazonTitanV2EmbeddingRequest:
|
||||
return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore
|
||||
|
||||
def _transform_response(
|
||||
self, response_list: List[dict], model: str
|
||||
) -> EmbeddingResponse:
|
||||
total_prompt_tokens = 0
|
||||
|
||||
transformed_responses: List[Embedding] = []
|
||||
for index, response in enumerate(response_list):
|
||||
_parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore
|
||||
transformed_responses.append(
|
||||
Embedding(
|
||||
embedding=_parsed_response["embedding"],
|
||||
index=index,
|
||||
object="embedding",
|
||||
)
|
||||
)
|
||||
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=total_prompt_tokens,
|
||||
)
|
||||
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanMultimodalEmbeddingConfig,
|
||||
AmazonTitanMultimodalEmbeddingRequest,
|
||||
AmazonTitanMultimodalEmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
|
||||
def _transform_request(
|
||||
input: str, inference_params: dict
|
||||
) -> AmazonTitanMultimodalEmbeddingRequest:
|
||||
## check if b64 encoded str or not ##
|
||||
is_encoded = is_base64_encoded(input)
|
||||
if is_encoded: # check if string is b64 encoded image or not
|
||||
transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputImage=input)
|
||||
else:
|
||||
transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)
|
||||
|
||||
for k, v in inference_params.items():
|
||||
transformed_request[k] = v # type: ignore
|
||||
|
||||
return transformed_request
|
||||
|
||||
|
||||
def _transform_response(response_list: List[dict], model: str) -> EmbeddingResponse:
|
||||
|
||||
total_prompt_tokens = 0
|
||||
transformed_responses: List[Embedding] = []
|
||||
for index, response in enumerate(response_list):
|
||||
_parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response) # type: ignore
|
||||
transformed_responses.append(
|
||||
Embedding(
|
||||
embedding=_parsed_response["embedding"], index=index, object="embedding"
|
||||
)
|
||||
)
|
||||
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=total_prompt_tokens,
|
||||
)
|
||||
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
|
86
litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py
Normal file
86
litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Convers
|
||||
- v2 request format
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanV2EmbeddingRequest,
|
||||
AmazonTitanV2EmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class AmazonTitanV2Config:
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
|
||||
normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
|
||||
dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
|
||||
"""
|
||||
|
||||
normalize: Optional[bool] = None
|
||||
dimensions: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def _transform_request(
|
||||
self, input: str, inference_params: dict
|
||||
) -> AmazonTitanV2EmbeddingRequest:
|
||||
return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore
|
||||
|
||||
def _transform_response(
|
||||
self, response_list: List[dict], model: str
|
||||
) -> EmbeddingResponse:
|
||||
total_prompt_tokens = 0
|
||||
|
||||
transformed_responses: List[Embedding] = []
|
||||
for index, response in enumerate(response_list):
|
||||
_parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore
|
||||
transformed_responses.append(
|
||||
Embedding(
|
||||
embedding=_parsed_response["embedding"],
|
||||
index=index,
|
||||
object="embedding",
|
||||
)
|
||||
)
|
||||
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=total_prompt_tokens,
|
||||
)
|
||||
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
|
25
litellm/llms/bedrock/embed/cohere_transformation.py
Normal file
25
litellm/llms/bedrock/embed/cohere_transformation.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse
|
||||
|
||||
|
||||
def _transform_request(
|
||||
input: List[str], inference_params: dict
|
||||
) -> CohereEmbeddingRequest:
|
||||
transformed_request = CohereEmbeddingRequest(
|
||||
texts=input,
|
||||
input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE, # type: ignore
|
||||
)
|
||||
|
||||
for k, v in inference_params.items():
|
||||
transformed_request[k] = v # type: ignore
|
||||
|
||||
return transformed_request
|
498
litellm/llms/bedrock/embed/embedding.py
Normal file
498
litellm/llms/bedrock/embed/embedding.py
Normal file
|
@ -0,0 +1,498 @@
|
|||
"""
|
||||
Handles embedding calls to Bedrock's `/invoke` endpoint
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm.llms.cohere.embed import embedding as cohere_embedding
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
from ...base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError, get_runtime_endpoint
|
||||
from .amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||
from .amazon_titan_multimodal_transformation import (
|
||||
_transform_request as amazon_multimodal_transform_request,
|
||||
)
|
||||
from .amazon_titan_multimodal_transformation import (
|
||||
_transform_response as amazon_multimodal_transform_response,
|
||||
)
|
||||
from .amazon_titan_v2_transformation import AmazonTitanV2Config
|
||||
from .cohere_transformation import _transform_request as cohere_transform_request
|
||||
|
||||
|
||||
class BedrockEmbedding(BaseAWSLLM):
|
||||
def _load_credentials(
|
||||
self,
|
||||
optional_params: dict,
|
||||
) -> Tuple[Any, str]:
|
||||
try:
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
return credentials, aws_region_name
|
||||
|
||||
async def async_embeddings(self):
|
||||
pass
|
||||
|
||||
def _make_sync_call(
|
||||
self,
|
||||
client: Optional[HTTPHandler],
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
) -> dict:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = _get_httpx_client(_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
try:
|
||||
response = client.post(url=api_base, headers=headers, data=json.dumps(data)) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return response.json()
|
||||
|
||||
def _single_func_embeddings(
|
||||
self,
|
||||
client: Optional[HTTPHandler],
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
batch_data: List[dict],
|
||||
credentials: Any,
|
||||
extra_headers: Optional[dict],
|
||||
endpoint_url: str,
|
||||
aws_region_name: str,
|
||||
model: str,
|
||||
logging_obj: Any,
|
||||
):
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
responses: List[dict] = []
|
||||
for data in batch_data:
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepped.url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
response = self._make_sync_call(
|
||||
client=client,
|
||||
timeout=timeout,
|
||||
api_base=prepped.url,
|
||||
headers=prepped.headers,
|
||||
data=data,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=data,
|
||||
api_key="",
|
||||
original_response=response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
responses.append(response)
|
||||
|
||||
returned_response: Optional[EmbeddingResponse] = None
|
||||
|
||||
## TRANSFORM RESPONSE ##
|
||||
if model == "amazon.titan-embed-image-v1":
|
||||
returned_response = amazon_multimodal_transform_response(
|
||||
response_list=responses, model=model
|
||||
)
|
||||
elif model == "amazon.titan-embed-text-v1":
|
||||
returned_response = AmazonTitanG1Config()._transform_response(
|
||||
response_list=responses, model=model
|
||||
)
|
||||
elif model == "amazon.titan-embed-text-v2:0":
|
||||
returned_response = AmazonTitanV2Config()._transform_response(
|
||||
response_list=responses, model=model
|
||||
)
|
||||
|
||||
if returned_response is None:
|
||||
raise Exception(
|
||||
"Unable to map model response to known provider format. model={}".format(
|
||||
model
|
||||
)
|
||||
)
|
||||
|
||||
return returned_response
|
||||
|
||||
def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: List[str],
|
||||
api_base: Optional[str],
|
||||
model_response: EmbeddingResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
aembedding: Optional[bool],
|
||||
extra_headers: Optional[dict],
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
) -> EmbeddingResponse:
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
credentials, aws_region_name = self._load_credentials(optional_params)
|
||||
|
||||
### TRANSFORMATION ###
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
modelId = (
|
||||
optional_params.pop("model_id", None) or model
|
||||
) # default to model if not passed
|
||||
|
||||
data: Optional[CohereEmbeddingRequest] = None
|
||||
batch_data: Optional[List] = None
|
||||
if provider == "cohere":
|
||||
data = cohere_transform_request(
|
||||
input=input, inference_params=inference_params
|
||||
)
|
||||
elif provider == "amazon" and model in [
|
||||
"amazon.titan-embed-image-v1",
|
||||
"amazon.titan-embed-text-v1",
|
||||
"amazon.titan-embed-text-v2:0",
|
||||
]:
|
||||
batch_data = []
|
||||
for i in input:
|
||||
if model == "amazon.titan-embed-image-v1":
|
||||
transformed_request: AmazonEmbeddingRequest = (
|
||||
amazon_multimodal_transform_request(
|
||||
input=i, inference_params=inference_params
|
||||
)
|
||||
)
|
||||
elif model == "amazon.titan-embed-text-v1":
|
||||
transformed_request = AmazonTitanG1Config()._transform_request(
|
||||
input=i, inference_params=inference_params
|
||||
)
|
||||
elif model == "amazon.titan-embed-text-v2:0":
|
||||
transformed_request = AmazonTitanV2Config()._transform_request(
|
||||
input=i, inference_params=inference_params
|
||||
)
|
||||
batch_data.append(transformed_request)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
),
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
|
||||
|
||||
if batch_data is not None:
|
||||
return self._single_func_embeddings(
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
timeout=timeout,
|
||||
batch_data=batch_data,
|
||||
credentials=credentials,
|
||||
extra_headers=extra_headers,
|
||||
endpoint_url=endpoint_url,
|
||||
aws_region_name=aws_region_name,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
elif data is None:
|
||||
raise Exception("Unable to map request to provider")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## ROUTING ##
|
||||
return cohere_embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
data=data, # type: ignore
|
||||
complete_api_base=prepped.url,
|
||||
api_key=None,
|
||||
aembedding=aembedding,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
headers=prepped.headers,
|
||||
)
|
||||
|
||||
# def _embedding_func_single(
|
||||
# model: str,
|
||||
# input: str,
|
||||
# client: Any,
|
||||
# optional_params=None,
|
||||
# encoding=None,
|
||||
# logging_obj=None,
|
||||
# ):
|
||||
# if isinstance(input, str) is False:
|
||||
# raise BedrockError(
|
||||
# message="Bedrock Embedding API input must be type str | List[str]",
|
||||
# status_code=400,
|
||||
# )
|
||||
# # logic for parsing in - calling - parsing out model embedding calls
|
||||
# ## FORMAT EMBEDDING INPUT ##
|
||||
# provider = model.split(".")[0]
|
||||
# inference_params = copy.deepcopy(optional_params)
|
||||
# inference_params.pop(
|
||||
# "user", None
|
||||
# ) # make sure user is not passed in for bedrock call
|
||||
# modelId = (
|
||||
# optional_params.pop("model_id", None) or model
|
||||
# ) # default to model if not passed
|
||||
# if provider == "amazon":
|
||||
# input = input.replace(os.linesep, " ")
|
||||
# data = {"inputText": input, **inference_params}
|
||||
# # data = json.dumps(data)
|
||||
# elif provider == "cohere":
|
||||
# inference_params["input_type"] = inference_params.get(
|
||||
# "input_type", "search_document"
|
||||
# ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
# data = {"texts": [input], **inference_params} # type: ignore
|
||||
# body = json.dumps(data).encode("utf-8") # type: ignore
|
||||
# ## LOGGING
|
||||
# request_str = f"""
|
||||
# response = client.invoke_model(
|
||||
# body={body},
|
||||
# modelId={modelId},
|
||||
# accept="*/*",
|
||||
# contentType="application/json",
|
||||
# )""" # type: ignore
|
||||
# logging_obj.pre_call(
|
||||
# input=input,
|
||||
# api_key="", # boto3 is used for init.
|
||||
# additional_args={
|
||||
# "complete_input_dict": {"model": modelId, "texts": input},
|
||||
# "request_str": request_str,
|
||||
# },
|
||||
# )
|
||||
# try:
|
||||
# response = client.invoke_model(
|
||||
# body=body,
|
||||
# modelId=modelId,
|
||||
# accept="*/*",
|
||||
# contentType="application/json",
|
||||
# )
|
||||
# response_body = json.loads(response.get("body").read())
|
||||
# ## LOGGING
|
||||
# logging_obj.post_call(
|
||||
# input=input,
|
||||
# api_key="",
|
||||
# additional_args={"complete_input_dict": data},
|
||||
# original_response=json.dumps(response_body),
|
||||
# )
|
||||
# if provider == "cohere":
|
||||
# response = response_body.get("embeddings")
|
||||
# # flatten list
|
||||
# response = [item for sublist in response for item in sublist]
|
||||
# return response
|
||||
# elif provider == "amazon":
|
||||
# return response_body.get("embedding")
|
||||
# except Exception as e:
|
||||
# raise BedrockError(
|
||||
# message=f"Embedding Error with model {model}: {e}", status_code=500
|
||||
# )
|
||||
|
||||
# def embedding(
|
||||
# model: str,
|
||||
# input: Union[list, str],
|
||||
# model_response: litellm.EmbeddingResponse,
|
||||
# api_key: Optional[str] = None,
|
||||
# logging_obj=None,
|
||||
# optional_params=None,
|
||||
# encoding=None,
|
||||
# ):
|
||||
# ### BOTO3 INIT ###
|
||||
# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
# aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
# aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
# aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
# aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
# aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
# "aws_bedrock_runtime_endpoint", None
|
||||
# )
|
||||
# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# # use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
# client = init_bedrock_client(
|
||||
# aws_access_key_id=aws_access_key_id,
|
||||
# aws_secret_access_key=aws_secret_access_key,
|
||||
# aws_region_name=aws_region_name,
|
||||
# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
# aws_web_identity_token=aws_web_identity_token,
|
||||
# aws_role_name=aws_role_name,
|
||||
# aws_session_name=aws_session_name,
|
||||
# )
|
||||
# if isinstance(input, str):
|
||||
# ## Embedding Call
|
||||
# embeddings = [
|
||||
# _embedding_func_single(
|
||||
# model,
|
||||
# input,
|
||||
# optional_params=optional_params,
|
||||
# client=client,
|
||||
# logging_obj=logging_obj,
|
||||
# )
|
||||
# ]
|
||||
# elif isinstance(input, list):
|
||||
# ## Embedding Call - assuming this is a List[str]
|
||||
# embeddings = [
|
||||
# _embedding_func_single(
|
||||
# model,
|
||||
# i,
|
||||
# optional_params=optional_params,
|
||||
# client=client,
|
||||
# logging_obj=logging_obj,
|
||||
# )
|
||||
# for i in input
|
||||
# ] # [TODO]: make these parallel calls
|
||||
# else:
|
||||
# # enters this branch if input = int, ex. input=2
|
||||
# raise BedrockError(
|
||||
# message="Bedrock Embedding API input must be type str | List[str]",
|
||||
# status_code=400,
|
||||
# )
|
||||
|
||||
# ## Populate OpenAI compliant dictionary
|
||||
# embedding_response = []
|
||||
# for idx, embedding in enumerate(embeddings):
|
||||
# embedding_response.append(
|
||||
# {
|
||||
# "object": "embedding",
|
||||
# "index": idx,
|
||||
# "embedding": embedding,
|
||||
# }
|
||||
# )
|
||||
# model_response.object = "list"
|
||||
# model_response.data = embedding_response
|
||||
# model_response.model = model
|
||||
# input_tokens = 0
|
||||
|
||||
# input_str = "".join(input)
|
||||
|
||||
# input_tokens += len(encoding.encode(input_str))
|
||||
|
||||
# usage = Usage(
|
||||
# prompt_tokens=input_tokens,
|
||||
# completion_tokens=0,
|
||||
# total_tokens=input_tokens + 0,
|
||||
# )
|
||||
# model_response.usage = usage
|
||||
|
||||
# return model_response
|
127
litellm/llms/bedrock/image_generation.py
Normal file
127
litellm/llms/bedrock/image_generation.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
Handles image gen calls to Bedrock's `/invoke` endpoint
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from .common_utils import BedrockError, init_bedrock_client
|
||||
|
||||
|
||||
def image_generation(
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
timeout=None,
|
||||
logging_obj=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
"""
|
||||
Bedrock Image Gen endpoint support
|
||||
"""
|
||||
### BOTO3 INIT ###
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
### FORMAT IMAGE GENERATION INPUT ###
|
||||
modelId = model
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
data = {}
|
||||
if provider == "stability":
|
||||
prompt = prompt.replace(os.linesep, " ")
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonStabilityConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
|
||||
else:
|
||||
raise BedrockError(
|
||||
status_code=422, message=f"Unsupported model={model}, passed in"
|
||||
)
|
||||
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
body={body}, # type: ignore
|
||||
modelId={modelId},
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={
|
||||
"complete_input_dict": {"model": modelId, "texts": prompt},
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
body=body,
|
||||
modelId=modelId,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message=f"Embedding Error with model {model}: {e}", status_code=500
|
||||
)
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
if response_body is None:
|
||||
raise Exception("Error in response object format")
|
||||
|
||||
if model_response is None:
|
||||
model_response = ImageResponse()
|
||||
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_body["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
return model_response
|
|
@ -76,7 +76,7 @@ async def async_embedding(
|
|||
data: dict,
|
||||
input: list,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
api_base: str,
|
||||
|
@ -98,16 +98,35 @@ async def async_embedding(
|
|||
)
|
||||
## COMPLETION CALL
|
||||
if client is None:
|
||||
client = AsyncHTTPHandler(concurrent_limit=1)
|
||||
client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
except httpx.HTTPStatusError as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=e.response.text,
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response,
|
||||
original_response=response.text,
|
||||
)
|
||||
|
||||
embeddings = response.json()["embeddings"]
|
||||
|
@ -130,27 +149,22 @@ def embedding(
|
|||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding: Any,
|
||||
data: Optional[dict] = None,
|
||||
complete_api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
aembedding: Optional[bool] = None,
|
||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
embed_url = "https://api.cohere.ai/v1/embed"
|
||||
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
|
||||
model = model
|
||||
data = {"model": model, "texts": input, **optional_params}
|
||||
data = data or {"model": model, "texts": input, **optional_params}
|
||||
|
||||
if "3" in model and "input_type" not in data:
|
||||
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
|
||||
data["input_type"] = "search_document"
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## ROUTING
|
||||
if aembedding is True:
|
||||
return async_embedding(
|
||||
|
@ -166,9 +180,18 @@ def embedding(
|
|||
headers=headers,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(concurrent_limit=1)
|
||||
|
||||
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
|
|
@ -78,7 +78,6 @@ from .llms import (
|
|||
ai21,
|
||||
aleph_alpha,
|
||||
baseten,
|
||||
bedrock,
|
||||
clarifai,
|
||||
cloudflare,
|
||||
maritalk,
|
||||
|
@ -96,7 +95,9 @@ from .llms.anthropic.chat import AnthropicChatCompletion
|
|||
from .llms.anthropic.completion import AnthropicTextCompletion
|
||||
from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
|
||||
from .llms.azure_text import AzureTextCompletion
|
||||
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
|
||||
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
|
||||
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from .llms.bedrock.embed.embedding import BedrockEmbedding
|
||||
from .llms.cohere import chat as cohere_chat
|
||||
from .llms.cohere import completion as cohere_completion # type: ignore
|
||||
from .llms.cohere import embed as cohere_embed
|
||||
|
@ -176,6 +177,7 @@ codestral_text_completions = CodestralTextCompletion()
|
|||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
bedrock_embedding = BedrockEmbedding()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
vertex_multimodal_embedding = VertexMultimodalEmbedding()
|
||||
google_batch_embeddings = GoogleBatchEmbeddings()
|
||||
|
@ -3151,6 +3153,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
|
|||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider == "cohere"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "bedrock"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3519,13 +3522,24 @@ def embedding(
|
|||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
response = bedrock.embedding(
|
||||
if isinstance(input, str):
|
||||
transformed_input = [input]
|
||||
else:
|
||||
transformed_input = input
|
||||
response = bedrock_embedding.embeddings(
|
||||
model=model,
|
||||
input=input,
|
||||
input=transformed_input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
client=client,
|
||||
timeout=timeout,
|
||||
aembedding=aembedding,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
print_verbose=print_verbose,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
if api_base is None:
|
||||
|
@ -4493,7 +4507,7 @@ def image_generation(
|
|||
elif custom_llm_provider == "bedrock":
|
||||
if model is None:
|
||||
raise Exception("Model needs to be set for bedrock")
|
||||
model_response = bedrock.image_generation(
|
||||
model_response = bedrock_image_generation.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
|
|
|
@ -178,7 +178,7 @@ async def bedrock_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
from litellm.llms.bedrock_httpx import BedrockConverseLLM
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM
|
||||
|
||||
credentials: Credentials = BedrockConverseLLM().get_credentials()
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
|
|
@ -25,7 +25,7 @@ from litellm import (
|
|||
completion_cost,
|
||||
embedding,
|
||||
)
|
||||
from litellm.llms.bedrock_httpx import BedrockLLM, ToolBlock
|
||||
from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
|
||||
|
||||
|
|
|
@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
|
|||
# test_cohere_embedding3()
|
||||
|
||||
|
||||
def test_bedrock_embedding_titan():
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bedrock/amazon.titan-embed-text-v1",
|
||||
"bedrock/amazon.titan-embed-image-v1",
|
||||
"bedrock/amazon.titan-embed-text-v2:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_embedding_titan(model, sync_mode):
|
||||
try:
|
||||
# this tests if we support str input for bedrock embedding
|
||||
litellm.set_verbose = True
|
||||
|
@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():
|
|||
|
||||
current_time = str(time.time())
|
||||
# DO NOT MAKE THE INPUT A LIST in this test
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
print(f"response:", response)
|
||||
if sync_mode:
|
||||
response = embedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
aws_region_name="us-west-2",
|
||||
)
|
||||
print("response:", response)
|
||||
assert isinstance(
|
||||
response["data"][0]["embedding"], list
|
||||
), "Expected response to be a list"
|
||||
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
print("type of first embedding:", type(response["data"][0]["embedding"][0]))
|
||||
assert all(
|
||||
isinstance(x, float) for x in response["data"][0]["embedding"]
|
||||
), "Expected response to be a list of floats"
|
||||
|
@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():
|
|||
|
||||
start_time = time.time()
|
||||
|
||||
response = embedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
if sync_mode:
|
||||
response = embedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=model,
|
||||
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
|
||||
)
|
||||
print(response)
|
||||
|
||||
end_time = time.time()
|
||||
print(response._hidden_params)
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
assert end_time - start_time < 0.1
|
||||
|
@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
|
|||
|
||||
with pytest.raises(
|
||||
litellm.BadRequestError,
|
||||
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
|
||||
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
|
||||
):
|
||||
litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
|
||||
|
||||
with pytest.raises(
|
||||
litellm.BadRequestError,
|
||||
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
|
||||
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
|
||||
):
|
||||
litellm.embedding(
|
||||
model="amazon.titan-embed-text-v1",
|
||||
|
|
|
@ -1,21 +1,25 @@
|
|||
import sys, os, uuid
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
from uuid import uuid4
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
from litellm import get_secret
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
|
@ -63,9 +67,7 @@ def test_oidc_github():
|
|||
reason="Cannot run without being in CircleCI Runner",
|
||||
)
|
||||
def test_oidc_circleci():
|
||||
secret_val = get_secret(
|
||||
"oidc/circleci/"
|
||||
)
|
||||
secret_val = get_secret("oidc/circleci/")
|
||||
|
||||
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||
|
||||
|
@ -103,9 +105,7 @@ def test_oidc_circle_v1_with_amazon():
|
|||
# The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token
|
||||
|
||||
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||
aws_role_name = (
|
||||
"arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
|
||||
)
|
||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
|
||||
aws_web_identity_token = "oidc/circleci/"
|
||||
|
||||
bllm = BedrockLLM()
|
||||
|
@ -116,6 +116,7 @@ def test_oidc_circle_v1_with_amazon():
|
|||
aws_session_name="assume-v1-session",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
|
||||
reason="Cannot run without being in CircleCI Runner",
|
||||
|
@ -124,9 +125,7 @@ def test_oidc_circle_v1_with_amazon_fips():
|
|||
# The purpose of this test is to validate that we can assume a role in a FIPS region
|
||||
|
||||
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||
aws_role_name = (
|
||||
"arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
|
||||
)
|
||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
|
||||
aws_web_identity_token = "oidc/circleci/"
|
||||
|
||||
bllm = BedrockConverseLLM()
|
||||
|
@ -143,9 +142,7 @@ def test_oidc_env_variable():
|
|||
# Create a unique environment variable name
|
||||
env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
|
||||
os.environ[env_var_name] = "secret-" + uuid4().hex
|
||||
secret_val = get_secret(
|
||||
f"oidc/env/{env_var_name}"
|
||||
)
|
||||
secret_val = get_secret(f"oidc/env/{env_var_name}")
|
||||
|
||||
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||
|
||||
|
@ -157,15 +154,13 @@ def test_oidc_env_variable():
|
|||
|
||||
def test_oidc_file():
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
|
||||
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
|
||||
secret_value = "secret-" + uuid4().hex
|
||||
temp_file.write(secret_value)
|
||||
temp_file.flush()
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
secret_val = get_secret(
|
||||
f"oidc/file/{temp_file_path}"
|
||||
)
|
||||
secret_val = get_secret(f"oidc/file/{temp_file_path}")
|
||||
|
||||
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||
|
||||
|
@ -174,7 +169,7 @@ def test_oidc_file():
|
|||
|
||||
def test_oidc_env_path():
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
|
||||
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
|
||||
secret_value = "secret-" + uuid4().hex
|
||||
temp_file.write(secret_value)
|
||||
temp_file.flush()
|
||||
|
@ -187,9 +182,7 @@ def test_oidc_env_path():
|
|||
os.environ[env_var_name] = temp_file_path
|
||||
|
||||
# Test getting the secret using the environment variable
|
||||
secret_val = get_secret(
|
||||
f"oidc/env_path/{env_var_name}"
|
||||
)
|
||||
secret_val = get_secret(f"oidc/env_path/{env_var_name}")
|
||||
|
||||
print(f"secret_val: {redact_oidc_signature(secret_val)}")
|
||||
|
||||
|
|
|
@ -208,3 +208,62 @@ class ServerSentEvent:
|
|||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
|
||||
|
||||
|
||||
class CohereEmbeddingRequest(TypedDict, total=False):
|
||||
texts: Required[List[str]]
|
||||
input_type: Required[
|
||||
Literal["search_document", "search_query", "classification", "clustering"]
|
||||
]
|
||||
truncate: Literal["NONE", "START", "END"]
|
||||
embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]
|
||||
|
||||
|
||||
class CohereEmbeddingResponse(TypedDict):
|
||||
embeddings: List[List[float]]
|
||||
id: str
|
||||
response_type: Literal["embedding_floats"]
|
||||
texts: List[str]
|
||||
|
||||
|
||||
class AmazonTitanV2EmbeddingRequest(TypedDict):
|
||||
inputText: str
|
||||
dimensions: int
|
||||
normalize: bool
|
||||
|
||||
|
||||
class AmazonTitanV2EmbeddingResponse(TypedDict):
|
||||
embedding: List[float]
|
||||
inputTextTokenCount: int
|
||||
|
||||
|
||||
class AmazonTitanG1EmbeddingRequest(TypedDict):
|
||||
inputText: str
|
||||
|
||||
|
||||
class AmazonTitanG1EmbeddingResponse(TypedDict):
|
||||
embedding: List[float]
|
||||
inputTextTokenCount: int
|
||||
|
||||
|
||||
class AmazonTitanMultimodalEmbeddingConfig(TypedDict):
|
||||
outputEmbeddingLength: Literal[256, 384, 1024]
|
||||
|
||||
|
||||
class AmazonTitanMultimodalEmbeddingRequest(TypedDict, total=False):
|
||||
inputText: str
|
||||
inputImage: str
|
||||
embeddingConfig: AmazonTitanMultimodalEmbeddingConfig
|
||||
|
||||
|
||||
class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
|
||||
embedding: List[float]
|
||||
inputTextTokenCount: int
|
||||
message: str # Specifies any errors that occur during generation.
|
||||
|
||||
|
||||
AmazonEmbeddingRequest = Union[
|
||||
AmazonTitanMultimodalEmbeddingRequest,
|
||||
AmazonTitanV2EmbeddingRequest,
|
||||
AmazonTitanG1EmbeddingRequest,
|
||||
]
|
||||
|
|
|
@ -699,7 +699,7 @@ class ModelResponse(OpenAIObject):
|
|||
class Embedding(OpenAIObject):
|
||||
embedding: Union[list, str] = []
|
||||
index: int
|
||||
object: str
|
||||
object: Literal["embedding"]
|
||||
|
||||
def get(self, key, default=None):
|
||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||
|
@ -721,7 +721,7 @@ class EmbeddingResponse(OpenAIObject):
|
|||
data: Optional[List] = None
|
||||
"""The actual embedding value"""
|
||||
|
||||
object: str
|
||||
object: Literal["list"]
|
||||
"""The object type, which is always "embedding" """
|
||||
|
||||
usage: Optional[Usage] = None
|
||||
|
@ -732,11 +732,10 @@ class EmbeddingResponse(OpenAIObject):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
usage=None,
|
||||
stream=False,
|
||||
model: Optional[str] = None,
|
||||
usage: Optional[Usage] = None,
|
||||
response_ms=None,
|
||||
data=None,
|
||||
data: Optional[List] = None,
|
||||
hidden_params=None,
|
||||
_response_headers=None,
|
||||
**params,
|
||||
|
@ -760,7 +759,7 @@ class EmbeddingResponse(OpenAIObject):
|
|||
self._response_headers = _response_headers
|
||||
|
||||
model = model
|
||||
super().__init__(model=model, object=object, data=data, usage=usage)
|
||||
super().__init__(model=model, object=object, data=data, usage=usage) # type: ignore
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
|
|
|
@ -854,6 +854,7 @@ def client(original_function):
|
|||
)
|
||||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
if cached_result is not None:
|
||||
print_verbose("Cache Hit!")
|
||||
if "detail" in cached_result:
|
||||
# implies an error occurred
|
||||
pass
|
||||
|
@ -935,7 +936,10 @@ def client(original_function):
|
|||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
return cached_result
|
||||
|
||||
else:
|
||||
print_verbose(
|
||||
"Cache Miss! on key - {}".format(preset_cache_key)
|
||||
)
|
||||
# CHECK MAX TOKENS
|
||||
if (
|
||||
kwargs.get("max_tokens", None) is not None
|
||||
|
@ -1005,7 +1009,7 @@ def client(original_function):
|
|||
litellm.cache is not None
|
||||
and str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
) and (kwargs.get("cache", {}).get("no-store", False) != True):
|
||||
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
|
||||
litellm.cache.add_cache(result, *args, **kwargs)
|
||||
|
||||
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
|
||||
|
@ -1404,10 +1408,10 @@ def client(original_function):
|
|||
# MODEL CALL
|
||||
result = await original_function(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
if "stream" in kwargs and kwargs["stream"] == True:
|
||||
if "stream" in kwargs and kwargs["stream"] is True:
|
||||
if (
|
||||
"complete_response" in kwargs
|
||||
and kwargs["complete_response"] == True
|
||||
and kwargs["complete_response"] is True
|
||||
):
|
||||
chunks = []
|
||||
for idx, chunk in enumerate(result):
|
||||
|
@ -11734,3 +11738,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
|
|||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_base64_encoded(s: str) -> bool:
|
||||
try:
|
||||
# Try to decode the string
|
||||
decoded_bytes = base64.b64decode(s, validate=True)
|
||||
# Check if the original string can be re-encoded to the same string
|
||||
return base64.b64encode(decoded_bytes).decode("utf-8") == s
|
||||
except Exception:
|
||||
return False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue