address feedback

Dinesh Yeduguru 2024-12-11 16:24:37 -08:00
parent e167e9eb93
commit 5821ec9ef3
12 changed files with 61 additions and 76 deletions

View file

@@ -205,12 +205,10 @@ API responses, specify the adapter here.
 def remote_provider_spec(
     api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
 ) -> RemoteProviderSpec:
-    if api_dependencies is None:
-        api_dependencies = []
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
-        api_dependencies=api_dependencies,
+        api_dependencies=api_dependencies or [],
     )
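
Note: the `or []` fallback makes the earlier None check redundant. A minimal standalone sketch of the idiom (normalize_deps is a hypothetical helper, not part of the codebase):

from typing import List, Optional

def normalize_deps(api_dependencies: Optional[List[str]] = None) -> List[str]:
    # `x or []` maps both None and an empty list to an empty list,
    # so no separate `if api_dependencies is None` branch is needed.
    return api_dependencies or []

assert normalize_deps(None) == []
assert normalize_deps([]) == []
assert normalize_deps(["inference"]) == ["inference"]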

View file

@@ -84,7 +84,7 @@ class MetaReferenceInferenceImpl(
     async def register_model(self, model: Model) -> Model:
         model = await self.model_registry_helper.register_model(model)
         if model.model_type == ModelType.embedding_model:
-            self._get_embedding_model(model.provider_resource_id)
+            self._load_sentence_transformer_model(model.provider_resource_id)
         return model
 
     async def completion(

View file

@@ -48,7 +48,7 @@ class SentenceTransformersInferenceImpl(
         )
 
     async def register_model(self, model: Model) -> None:
-        _ = self._get_embedding_model(model.provider_resource_id)
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model
 
     async def unregister_model(self, model_id: str) -> None:
@@ -63,7 +63,7 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncGenerator]:
-        raise NotImplementedError("Sentence transformers don't support completion")
+        raise ValueError("Sentence transformers don't support completion")
 
     async def chat_completion(
         self,
@@ -77,4 +77,4 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError("Sentence transformers don't support chat completion")
+        raise ValueError("Sentence transformers don't support chat completion")

View file

@@ -20,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 
 model_aliases = [
@@ -452,7 +454,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model = await self.model_store.get_model(model_id)
         embeddings = []
         for content in contents:
-            input_text = str(content) if not isinstance(content, str) else content
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
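
Note: this media guard plus text coercion is the pattern repeated across the Bedrock, Fireworks, Ollama, Together, and vLLM adapters below. A simplified, standalone sketch of the idea; the two helpers are stand-ins for the real content_has_media / interleaved_text_media_as_str utilities, reimplemented here only so the snippet runs on its own:

from typing import List, Union

def content_has_media(content: Union[str, List]) -> bool:
    # Simplified stand-in: treat anything that is not plain text as media.
    if isinstance(content, list):
        return any(not isinstance(item, str) for item in content)
    return not isinstance(content, str)

def interleaved_text_media_as_str(content: Union[str, List]) -> str:
    # Simplified stand-in: flatten interleaved text into a single string.
    return content if isinstance(content, str) else " ".join(str(item) for item in content)

def to_embedding_inputs(contents: List[Union[str, List]], provider: str) -> List[str]:
    assert all(
        not content_has_media(c) for c in contents
    ), f"{provider} does not support media for embeddings"
    return [interleaved_text_media_as_str(c) for c in contents]

assert to_embedding_inputs(["hello", ["a", "b"]], "Bedrock") == ["hello", "a b"]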

View file

@@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
@@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
-    def _get_openai_client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
-
     async def completion(
         self,
         model_id: str,
@@ -272,12 +269,16 @@ class FireworksInferenceAdapter(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
-        client = self._get_openai_client()
         kwargs = {}
         if model.metadata.get("embedding_dimensions"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
-        response = client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
 
         embeddings = [data.embedding for data in response.data]
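
Note: the `dimensions` kwarg is only passed when the registered model's metadata declares it. An illustrative sketch of that branching; build_embedding_kwargs is a hypothetical helper, not code from this file:

from typing import Any, Dict

def build_embedding_kwargs(metadata: Dict[str, Any]) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {}
    if metadata.get("embedding_dimensions"):
        # Only models registered with an explicit dimension get the extra kwarg;
        # other models keep the provider's default output size.
        kwargs["dimensions"] = metadata["embedding_dimensions"]
    return kwargs

assert build_embedding_kwargs({}) == {}
assert build_embedding_kwargs({"embedding_dimensions": 768}) == {"dimensions": 768}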

View file

@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -323,8 +324,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
 
         embeddings = response["embeddings"]

View file

@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -254,8 +255,12 @@ class TogetherInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)

View file

@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -206,10 +207,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model = await self.model_store.get_model(model_id)
 
         kwargs = {}
-        if model.metadata.get("embedding_dimensions"):
-            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
 
         embeddings = [data.embedding for data in response.data]
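
Note: unlike the other adapters, the vLLM path now hard-requires an embedding-typed model with declared dimensions. A sketch of registration data that would satisfy those asserts (plain dict with a placeholder model id; the real code uses the Model/ModelInput types seen elsewhere in this commit):

example_model = {
    "model_id": "intfloat/e5-mistral-7b-instruct",  # placeholder id
    "model_type": "embedding_model",
    "metadata": {"embedding_dimensions": 4096},
}

# Mirrors the two asserts above, expressed against the plain dict.
assert example_model["model_type"] == "embedding_model"
assert example_model["metadata"].get("embedding_dimensions")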

View file

@@ -84,24 +84,3 @@ def pytest_generate_tests(metafunc):
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
-
-    if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
-        if not model:
-            raise ValueError(
-                "No embedding model specified. Please provide a valid embedding model."
-            )
-        params = [pytest.param(model, id="")]
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "embedding_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
-            metafunc.config,
-            {
-                "inference": INFERENCE_FIXTURES,
-            },
-        ):
-            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-        metafunc.parametrize("embedding_stack", fixtures, indirect=True)

View file

@@ -37,13 +37,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
 
 
-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    test_stack = await construct_stack_for_test(
-        [Api.inference],
-        {"inference": inference_fixture.providers},
-        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
-    )
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
 
-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def embedding_stack(request, embedding_model):
-    fixture_name = request.param
-    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding_model,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
             )
         ],
     )
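
Note: with embedding_stack and embedding_model gone, a single inference_stack fixture serves both LLM and embedding runs, keyed off the EMBEDDING_DIMENSION environment variable (and presumably an embedding model passed via --inference-model rather than the removed --embedding-model option). A minimal sketch of the branching, with plain strings standing in for the ModelType enum:

import os

def resolve_model_type_and_metadata() -> tuple:
    model_type = "llm"
    metadata = {}
    if os.getenv("EMBEDDING_DIMENSION"):
        # Mirrors the fixture above: the env var flips the registered model to
        # an embedding model and records its dimension in metadata.
        model_type = "embedding_model"
        metadata["embedding_dimension"] = os.environ["EMBEDDING_DIMENSION"]
    return model_type, metadata

os.environ["EMBEDDING_DIMENSION"] = "384"
assert resolve_model_type_and_metadata() == ("embedding_model", {"embedding_dimension": "384"})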

View file

@@ -14,15 +14,15 @@ from llama_stack.apis.inference import EmbeddingsResponse, ModelType
 class TestEmbeddings:
     @pytest.mark.asyncio
-    async def test_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
 
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=["Hello, world!"],
         )
         assert isinstance(response, EmbeddingsResponse)
@@ -35,9 +35,9 @@ class TestEmbeddings:
         )
 
     @pytest.mark.asyncio
-    async def test_batch_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
@@ -45,7 +45,7 @@ class TestEmbeddings:
         texts = ["Hello, world!", "This is a test", "Testing embeddings"]
 
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=texts,
         )

View file

@@ -26,11 +26,13 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._get_embedding_model(model.provider_resource_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)
 
-    def _get_embedding_model(self, model: str) -> "SentenceTransformer":
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
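
Note: the renamed helper sits in front of a module-level cache, so each SentenceTransformer is loaded once and reused across embedding calls. A simplified, standalone sketch of that lazy-load pattern (the hunk above only shows the cache lookup; the load-and-store step here is an assumption):

from sentence_transformers import SentenceTransformer

EMBEDDING_MODELS: dict = {}

def load_sentence_transformer_model(model: str) -> SentenceTransformer:
    # Return the cached instance if this model name was loaded before.
    loaded_model = EMBEDDING_MODELS.get(model)
    if loaded_model is not None:
        return loaded_model
    # Otherwise load it once and cache it for subsequent calls.
    loaded_model = SentenceTransformer(model)
    EMBEDDING_MODELS[model] = loaded_model
    return loaded_model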