Memory tests pass now

This commit is contained in:
Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@ -9,21 +9,24 @@ import json
import uuid
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
interleaved_content_as_str,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
MODEL_ALIASES = [
@ -64,7 +67,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -449,7 +452,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
@ -457,7 +460,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
assert not content_has_media(
content
), "Bedrock does not support media for embeddings"
input_text = interleaved_text_media_as_str(content)
input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(