address feedback

Dinesh Yeduguru 2024-12-11 16:24:37 -08:00
parent e167e9eb93
commit 5821ec9ef3
12 changed files with 61 additions and 76 deletions

@@ -20,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 model_aliases = [
@@ -452,7 +454,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model = await self.model_store.get_model(model_id)
         embeddings = []
         for content in contents:
-            input_text = str(content) if not isinstance(content, str) else content
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
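
Note: with this change the Bedrock loop rejects media per item and flattens interleaved content to plain text before building the inputText body. A minimal sketch of that per-item flow, assuming interleaved_text_media_as_str is imported from prompt_adapter alongside content_has_media, and assuming a Titan-style {"embedding": [...]} response body (the hunk is truncated before the response handling):

import json

from llama_stack.providers.utils.inference.prompt_adapter import (
    content_has_media,
    interleaved_text_media_as_str,
)


def embed_one(client, model_id: str, content) -> list:
    # Bedrock text-embedding models only accept text, so refuse anything carrying media.
    assert not content_has_media(content), "Bedrock does not support media for embeddings"
    input_text = interleaved_text_media_as_str(content)
    body = json.dumps({"inputText": input_text})
    response = client.invoke_model(body=body, modelId=model_id)
    # Assumed response shape for Titan-style embedding models.
    return json.loads(response["body"].read())["embedding"]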

@@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
@@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
-    def _get_openai_client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
     async def completion(
         self,
         model_id: str,
@@ -272,12 +269,16 @@ class FireworksInferenceAdapter(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        client = self._get_openai_client()
         kwargs = {}
         if model.metadata.get("embedding_dimensions"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
-        response = client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
         embeddings = [data.embedding for data in response.data]
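
Note: the Fireworks path now drops the separate OpenAI client and issues a single batched embeddings.create on the Fireworks client, forwarding dimensions only when the registered model declares it. A small sketch of that optional-kwargs pattern (the helper name is illustrative, not part of this change):

from typing import Any, Dict


def embedding_kwargs(metadata: Dict[str, Any]) -> Dict[str, Any]:
    # Forward `dimensions` only when the model's metadata declares it; otherwise the
    # provider's default output size is used.
    kwargs: Dict[str, Any] = {}
    if metadata.get("embedding_dimensions"):
        kwargs["dimensions"] = metadata["embedding_dimensions"]
    return kwargs


# embedding_kwargs({"embedding_dimensions": 768}) -> {"dimensions": 768}
# embedding_kwargs({})                            -> {}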

@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -323,8 +324,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]
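
Note: the assert-no-media plus to-string block above is repeated almost verbatim in the Fireworks hunk above and in the Together and vLLM hunks below. A possible follow-up, sketched here with a hypothetical helper name (not part of this change, import path assumed from the hunks above), would be to hoist it next to the existing prompt_adapter helpers:

from typing import List

from llama_stack.providers.utils.inference.prompt_adapter import (
    content_has_media,
    interleaved_text_media_as_str,
)


def assert_text_only_and_flatten(contents, provider: str) -> List[str]:
    # Fail fast on media content, then hand the provider SDK plain strings.
    assert all(
        not content_has_media(content) for content in contents
    ), f"{provider} does not support media for embeddings"
    return [interleaved_text_media_as_str(content) for content in contents]


# e.g. input=assert_text_only_and_flatten(contents, "Ollama")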

@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -254,8 +255,12 @@ class TogetherInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)

@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -206,10 +207,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model = await self.model_store.get_model(model_id)
         kwargs = {}
-        if model.metadata.get("embedding_dimensions"):
-            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
         embeddings = [data.embedding for data in response.data]
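
Note: unlike the Fireworks path above, the vLLM path now requires the model to be registered as an embedding model with embedding_dimensions metadata, and always forwards dimensions. Roughly the request the adapter ends up issuing against vLLM's OpenAI-compatible server (a sketch; URL, API key, model name, and dimension value are placeholders, and whether a given vLLM deployment honors dimensions is an assumption):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # placeholder provider_resource_id
    input=["first text", "second text"],      # plain strings only; media is rejected upstream
    dimensions=4096,                          # taken from the model's embedding_dimensions metadata
)
embeddings = [d.embedding for d in response.data]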