mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 17:59:52 +00:00
implement embedding generation in supported inference providers
This commit is contained in:
parent
b896be2311
commit
e167e9eb93
16 changed files with 383 additions and 29 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
import json
|
||||
|
||||
from botocore.client import BaseClient
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
|
@ -448,4 +449,18 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
input_text = str(content) if not isinstance(content, str) else content
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue