Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit: 5821ec9ef3 ("address feedback")
Parent: e167e9eb93
12 changed files with 61 additions and 76 deletions
@@ -205,12 +205,10 @@ API responses, specify the adapter here.
 def remote_provider_spec(
     api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
 ) -> RemoteProviderSpec:
-    if api_dependencies is None:
-        api_dependencies = []
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
-        api_dependencies=api_dependencies,
+        api_dependencies=api_dependencies or [],
     )
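Note on the hunk above: `api_dependencies or []` collapses both the None default and an explicitly empty list into a fresh empty list, which is what the removed two-line None check did. A minimal standalone sketch of the idiom (the function name is illustrative, not from the repo):

    from typing import List, Optional


    def normalize_deps(api_dependencies: Optional[List[str]] = None) -> List[str]:
        # `x or []` treats None (and an already-empty list) as "no dependencies".
        return api_dependencies or []


    assert normalize_deps() == []
    assert normalize_deps(None) == []
    assert normalize_deps(["inference"]) == ["inference"]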
@@ -84,7 +84,7 @@ class MetaReferenceInferenceImpl(
     async def register_model(self, model: Model) -> Model:
         model = await self.model_registry_helper.register_model(model)
         if model.model_type == ModelType.embedding_model:
-            self._get_embedding_model(model.provider_resource_id)
+            self._load_sentence_transformer_model(model.provider_resource_id)
         return model
 
     async def completion(
@@ -48,7 +48,7 @@ class SentenceTransformersInferenceImpl(
         )
 
     async def register_model(self, model: Model) -> None:
-        _ = self._get_embedding_model(model.provider_resource_id)
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model
 
     async def unregister_model(self, model_id: str) -> None:
@@ -63,7 +63,7 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncGenerator]:
-        raise NotImplementedError("Sentence transformers don't support completion")
+        raise ValueError("Sentence transformers don't support completion")
 
     async def chat_completion(
         self,
@@ -77,4 +77,4 @@ class SentenceTransformersInferenceImpl(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError("Sentence transformers don't support chat completion")
+        raise ValueError("Sentence transformers don't support chat completion")
@@ -20,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 
 from llama_stack.apis.inference import *  # noqa: F403
 
+
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 
 
 model_aliases = [
@@ -452,7 +454,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model = await self.model_store.get_model(model_id)
         embeddings = []
         for content in contents:
-            input_text = str(content) if not isinstance(content, str) else content
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
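For reference, a rough standalone sketch of the per-item flow above: each content item is rejected if it carries media, flattened to plain text, and serialized into the JSON body Bedrock's embedding models expect. The two helpers below are simplified stand-ins for the llama-stack prompt_adapter utilities, not their real implementations:

    import json


    def content_has_media(content) -> bool:
        # Stand-in: the real helper inspects interleaved content for image attachments.
        return not isinstance(content, str)


    def interleaved_text_media_as_str(content) -> str:
        # Stand-in: the real helper flattens interleaved text/media into one string.
        return content if isinstance(content, str) else str(content)


    def bedrock_embedding_bodies(contents):
        bodies = []
        for content in contents:
            # Embeddings are text-only: fail loudly on media.
            assert not content_has_media(content), "Bedrock does not support media for embeddings"
            input_text = interleaved_text_media_as_str(content)
            bodies.append(json.dumps({"inputText": input_text}))
        return bodies


    print(bedrock_embedding_bodies(["Hello, world!"]))  # ['{"inputText": "Hello, world!"}']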
@@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
@@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
-    def _get_openai_client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
-
     async def completion(
         self,
         model_id: str,
@@ -272,12 +269,16 @@ class FireworksInferenceAdapter(
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
-        client = self._get_openai_client()
         kwargs = {}
         if model.metadata.get("embedding_dimensions"):
            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
-        response = client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
 
         embeddings = [data.embedding for data in response.data]
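The Fireworks change above and the Ollama, Together, and vLLM hunks below converge on one recipe: assert that no item carries media, flatten each item with interleaved_text_media_as_str, and forward an optional dimensions value. A condensed sketch of that shared shape; the client argument and both helpers are stand-ins rather than the adapters' real attributes:

    from typing import Any, List, Optional


    def content_has_media(content: Any) -> bool:
        # Stand-in for the prompt_adapter helper.
        return not isinstance(content, str)


    def interleaved_text_media_as_str(content: Any) -> str:
        # Stand-in: flatten interleaved text/media content to a plain string.
        return content if isinstance(content, str) else str(content)


    def embed_texts(
        client: Any,
        model_id: str,
        contents: List[Any],
        embedding_dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        # Shared guard across the adapters touched here: embeddings are text-only.
        assert all(not content_has_media(c) for c in contents), "media is not supported for embeddings"

        kwargs = {}
        if embedding_dimensions:  # forwarded only when the model declares it
            kwargs["dimensions"] = embedding_dimensions

        # OpenAI-compatible embeddings call, as exposed by the Fireworks/Together/vLLM clients.
        response = client.embeddings.create(
            model=model_id,
            input=[interleaved_text_media_as_str(c) for c in contents],
            **kwargs,
        )
        return [item.embedding for item in response.data]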
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -323,8 +324,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
         response = await self.client.embed(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = response["embeddings"]
 
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -254,8 +255,12 @@ class TogetherInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
         r = self._get_client().embeddings.create(
-            model=model.provider_resource_id, input=contents
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -206,10 +207,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model = await self.model_store.get_model(model_id)
 
         kwargs = {}
-        if model.metadata.get("embedding_dimensions"):
-            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
         response = self.client.embeddings.create(
-            model=model.provider_resource_id, input=contents, **kwargs
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
         )
 
         embeddings = [data.embedding for data in response.data]
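The vLLM path is the strictest of the group: the registered model must be an embedding model and must carry embedding-dimension metadata before the dimensions kwarg is forwarded. A hedged sketch of a registration that satisfies those asserts, written as plain data rather than the real ModelInput type; the identifier and dimension are illustrative:

    # Illustrative registration shape only; key names follow this diff.
    embedding_model_registration = {
        "model_id": "my-embedding-model",           # hypothetical identifier
        "model_type": "embedding_model",            # ModelType.embedding_model in the real API
        "metadata": {"embedding_dimensions": 384},  # read back and forwarded as `dimensions`
    }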
@@ -84,24 +84,3 @@ def pytest_generate_tests(metafunc):
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
-
-    if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
-        if not model:
-            raise ValueError(
-                "No embedding model specified. Please provide a valid embedding model."
-            )
-        params = [pytest.param(model, id="")]
-
-        metafunc.parametrize("embedding_model", params, indirect=True)
-
-    if "embedding_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
-            metafunc.config,
-            {
-                "inference": INFERENCE_FIXTURES,
-            },
-        ):
-            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-        metafunc.parametrize("embedding_stack", fixtures, indirect=True)
@@ -37,13 +37,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
 
 
-@pytest.fixture(scope="session")
-def embedding_model(request):
-    if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--embedding-model", None)
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    test_stack = await construct_stack_for_test(
-        [Api.inference],
-        {"inference": inference_fixture.providers},
-        inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
-    )
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
 
-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def embedding_stack(request, embedding_model):
-    fixture_name = request.param
-    inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=embedding_model,
-                model_type=ModelType.embedding_model,
-                metadata={
-                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
-                },
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
             )
         ],
     )
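With the dedicated embedding fixtures gone, a single environment variable now decides whether the test model is registered as an LLM or as an embedding model. A minimal sketch of that branching with a stand-in for ModelInput (the real class and get_env_or_fail come from llama-stack; only the decision logic is mirrored here, and the model id is illustrative):

    import os
    from dataclasses import dataclass, field
    from typing import Dict


    @dataclass
    class FakeModelInput:
        # Stand-in for llama-stack's ModelInput, just to show the shape.
        model_id: str
        model_type: str = "llm"
        metadata: Dict = field(default_factory=dict)


    def build_test_model(inference_model: str) -> FakeModelInput:
        model_type = "llm"
        metadata = {}
        # Same switch as the reworked inference_stack fixture: setting
        # EMBEDDING_DIMENSION registers the model as an embedding model instead.
        if os.getenv("EMBEDDING_DIMENSION"):
            model_type = "embedding_model"
            metadata["embedding_dimension"] = os.environ["EMBEDDING_DIMENSION"]
        return FakeModelInput(model_id=inference_model, model_type=model_type, metadata=metadata)


    print(build_test_model("all-MiniLM-L6-v2"))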
@@ -14,15 +14,15 @@ from llama_stack.apis.inference import EmbeddingsResponse, ModelType
 
 class TestEmbeddings:
     @pytest.mark.asyncio
-    async def test_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
 
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
 
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=["Hello, world!"],
         )
         assert isinstance(response, EmbeddingsResponse)
@@ -35,9 +35,9 @@ class TestEmbeddings:
         )
 
     @pytest.mark.asyncio
-    async def test_batch_embeddings(self, embedding_model, embedding_stack):
-        inference_impl, models_impl = embedding_stack
-        model = await models_impl.get_model(embedding_model)
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
 
         if model.model_type != ModelType.embedding_model:
             pytest.skip("This test is only applicable for embedding models")
@@ -45,7 +45,7 @@ class TestEmbeddings:
         texts = ["Hello, world!", "This is a test", "Testing embeddings"]
 
         response = await inference_impl.embeddings(
-            model_id=embedding_model,
+            model_id=inference_model,
             contents=texts,
         )
 
@@ -26,11 +26,13 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._get_embedding_model(model.provider_resource_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)
 
-    def _get_embedding_model(self, model: str) -> "SentenceTransformer":
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
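The renamed _load_sentence_transformer_model keeps the module-level cache shown above; a hedged sketch of that load-once pattern (the real mixin lives in llama-stack's sentence-transformers utilities, this stub only mirrors the caching):

    from typing import Dict


    # Module-level cache so each SentenceTransformer is loaded at most once per process.
    EMBEDDING_MODELS: Dict[str, object] = {}


    def load_sentence_transformer_model(model: str):
        loaded_model = EMBEDDING_MODELS.get(model)
        if loaded_model is not None:
            return loaded_model

        # Deferred import keeps sentence-transformers optional until embeddings are needed.
        from sentence_transformers import SentenceTransformer

        loaded_model = SentenceTransformer(model)
        EMBEDDING_MODELS[model] = loaded_model
        return loaded_model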