Make embedding generation go through inference (#606)

This PR does the following:
1) adds the ability to generate embeddings in all supported inference
providers.
2) Moves all the memory providers to use the inference API and improved
the memory tests to setup the inference stack correctly and use the
embedding models

This is a merge from #589 and #598
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:47:50 -08:00 committed by GitHub
parent a14785af46
commit 96e158eaac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 677 additions and 156 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_message_to_dict,
request_has_media,
)
@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
async def shutdown(self) -> None:
pass
def _get_client(self) -> Fireworks:
fireworks_api_key = None
def _get_api_key(self) -> str:
if self.config.api_key is not None:
fireworks_api_key = self.config.api_key
return self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
)
fireworks_api_key = provider_data.fireworks_api_key
return provider_data.fireworks_api_key
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key)
async def completion(
@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(
not content_has_media(content) for content in contents
), "Fireworks does not support media for embeddings"
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)