implement embedding generation in supported inference providers

Dinesh Yeduguru 2024-12-09 12:48:56 -08:00
parent b896be2311
commit e167e9eb93
16 changed files with 383 additions and 29 deletions


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 from typing import *  # noqa: F403
+import json
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -448,4 +449,18 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            input_text = str(content) if not isinstance(content, str) else content
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)
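
For context, the per-item loop above mirrors a raw Bedrock runtime call: one JSON request per text, with the embedding vector returned under the "embedding" key. The sketch below shows the same request/response shape using boto3 directly, outside the adapter. It is illustrative only; the model id amazon.titan-embed-text-v1 and the sample texts are assumptions for the example, not part of this commit.

# Illustrative sketch (not part of this commit): invoking a Bedrock embedding
# model directly with boto3, matching the request/response shape the adapter uses.
import json

import boto3

# Region and credentials come from your AWS configuration.
client = boto3.client("bedrock-runtime")

texts = ["What is the capital of France?", "Llamas are social animals."]
embeddings = []
for text in texts:
    response = client.invoke_model(
        body=json.dumps({"inputText": text}),
        modelId="amazon.titan-embed-text-v1",  # hypothetical model choice for this sketch
        accept="application/json",
        contentType="application/json",
    )
    # Titan-style responses carry the vector under "embedding".
    response_body = json.loads(response["body"].read())
    embeddings.append(response_body["embedding"])

print(len(embeddings), len(embeddings[0]))  # e.g. 2 vectors of fixed dimension

Note that the adapter issues one invoke_model call per content item rather than batching; providers whose embedding endpoints accept a list of inputs could collapse the loop into a single request.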