implement embedding generation in supported inference providers (#589)

This PR adds the ability to generate embeddings in all supported
inference providers.

```
pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0"  --env EMBEDDING_DIMENSION=1024

 pytest -v -s -k "vllm"  --inferrence-model="intfloat/e5-mistral-7b-instruct"  llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096  --env VLLM_URL="http://localhost:9798/v1"

pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5"  llama_stack/providers/tests/inference/test_embeddings.py  -k "fireworks"  --env FIREWORKS_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval"  llama_stack/providers/tests/inference/test_embeddings.py  -k "together"  --env TOGETHER_API_KEY=<API_KEY>--env EMBEDDING_DIMENSION=768

pytest -v -s -k "ollama"  --inference-model="all-minilm:v8"  llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

 torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2"  llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

```
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:17:39 -08:00
parent 6a23f24ee0
commit d362d2d740
32 changed files with 597 additions and 143 deletions

View file

@ -9,9 +9,9 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture(
providers=[
@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
if "Llama3.1-8B-Instruct" in inference_model:
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture(
@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding_model
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[ModelInput(model_id=inference_model)],
models=[
ModelInput(
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
)
return test_stack.impls[Api.inference], test_stack.impls[Api.models]