Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding requests into Amazon embedding formats

* fix: fix linting errors

* test: skip test on end-of-life model

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
21 changed files with 1946 additions and 1659 deletions
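
For orientation, the refactored Bedrock embedding path is still reached through litellm's standard embedding entry point. A minimal usage sketch follows; the Titan model ID and the printed fields are assumptions for illustration, not taken from this commit:

import litellm

# Minimal sketch (not from this commit): call an Amazon Titan embedding model
# on Bedrock through the standard litellm entry point.
response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",  # model ID assumed for illustration
    input=["good morning from litellm"],
)

print(response.object)     # "list", per the tightened typing in this PR
print(len(response.data))  # one embedding per input string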


@@ -208,3 +208,62 @@ class ServerSentEvent:
     @override
     def __repr__(self) -> str:
         return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
+
+
+class CohereEmbeddingRequest(TypedDict, total=False):
+    texts: Required[List[str]]
+    input_type: Required[
+        Literal["search_document", "search_query", "classification", "clustering"]
+    ]
+    truncate: Literal["NONE", "START", "END"]
+    embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]
+
+
+class CohereEmbeddingResponse(TypedDict):
+    embeddings: List[List[float]]
+    id: str
+    response_type: Literal["embedding_floats"]
+    texts: List[str]
+
+
+class AmazonTitanV2EmbeddingRequest(TypedDict):
+    inputText: str
+    dimensions: int
+    normalize: bool
+
+
+class AmazonTitanV2EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanG1EmbeddingRequest(TypedDict):
+    inputText: str
+
+
+class AmazonTitanG1EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanMultimodalEmbeddingConfig(TypedDict):
+    outputEmbeddingLength: Literal[256, 384, 1024]
+
+
+class AmazonTitanMultimodalEmbeddingRequest(TypedDict, total=False):
+    inputText: str
+    inputImage: str
+    embeddingConfig: AmazonTitanMultimodalEmbeddingConfig
+
+
+class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+    message: str  # Specifies any errors that occur during generation.
+
+
+AmazonEmbeddingRequest = Union[
+    AmazonTitanMultimodalEmbeddingRequest,
+    AmazonTitanV2EmbeddingRequest,
+    AmazonTitanG1EmbeddingRequest,
+]
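
To make the request shapes above concrete, here is a small, hypothetical helper that builds an AmazonTitanV2EmbeddingRequest from a plain input string. The function name and default values are illustrative only; the real request/response translation added by this PR lives in the new bedrock/embeddings.py, which is not shown in this hunk.

# Hypothetical helper, for illustration only. Assumes AmazonTitanV2EmbeddingRequest
# (defined above) is in scope or imported from wherever these types live.
def to_titan_v2_request(
    text: str,
    dimensions: int = 1024,   # default assumed for illustration
    normalize: bool = True,   # default assumed for illustration
) -> AmazonTitanV2EmbeddingRequest:
    # All three keys are required: the TypedDict above is declared with total=True (the default).
    return {"inputText": text, "dimensions": dimensions, "normalize": normalize}


# e.g. {'inputText': 'hello world', 'dimensions': 256, 'normalize': True}
print(to_titan_v2_request("hello world", dimensions=256))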


@@ -699,7 +699,7 @@ class ModelResponse(OpenAIObject):
 class Embedding(OpenAIObject):
     embedding: Union[list, str] = []
     index: int
-    object: str
+    object: Literal["embedding"]
 
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -721,7 +721,7 @@ class EmbeddingResponse(OpenAIObject):
     data: Optional[List] = None
     """The actual embedding value"""
 
-    object: str
+    object: Literal["list"]
     """The object type, which is always "embedding" """
 
     usage: Optional[Usage] = None
@@ -732,11 +732,10 @@ class EmbeddingResponse(OpenAIObject):
     def __init__(
         self,
-        model=None,
-        usage=None,
-        stream=False,
+        model: Optional[str] = None,
+        usage: Optional[Usage] = None,
         response_ms=None,
-        data=None,
+        data: Optional[List] = None,
         hidden_params=None,
         _response_headers=None,
         **params,
     ):
@@ -760,7 +759,7 @@ class EmbeddingResponse(OpenAIObject):
         self._response_headers = _response_headers
         model = model
 
-        super().__init__(model=model, object=object, data=data, usage=usage)
+        super().__init__(model=model, object=object, data=data, usage=usage)  # type: ignore
 
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
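
As a quick illustration of the tightened typing, here is a hypothetical construction of an EmbeddingResponse (the import path and all values are assumed for illustration): object is now constrained to "embedding" on each item and "list" on the response.

from litellm.types.utils import Embedding, EmbeddingResponse, Usage  # import path assumed

# Illustrative values only.
item = Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")
response = EmbeddingResponse(
    model="amazon.titan-embed-text-v1",
    data=[item],
    usage=Usage(prompt_tokens=5, completion_tokens=0, total_tokens=5),
)

print(response.object)          # "list"
print(response.data[0].object)  # "embedding"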