implement embedding generation in supported inference providers

This commit is contained in:
Dinesh Yeduguru 2024-12-09 12:48:56 -08:00
parent b896be2311
commit e167e9eb93
16 changed files with 383 additions and 29 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from typing import * # noqa: F403
import json
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
@ -448,4 +449,18 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
embeddings = []
for content in contents:
input_text = str(content) if not isinstance(content, str) else content
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=model.provider_resource_id,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
@json_schema_type
class FireworksImplConfig(BaseModel):
url: str = Field(
default="https://api.fireworks.ai/inference",
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: Optional[str] = Field(
@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference",
"url": "https://api.fireworks.ai/inference/v1",
"api_key": "${env.FIREWORKS_API_KEY}",
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
@ -12,6 +12,7 @@ from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
@ -89,19 +90,24 @@ class FireworksInferenceAdapter(
async def shutdown(self) -> None:
pass
def _get_client(self) -> Fireworks:
fireworks_api_key = None
def _get_api_key(self) -> str:
if self.config.api_key is not None:
fireworks_api_key = self.config.api_key
return self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
)
fireworks_api_key = provider_data.fireworks_api_key
return provider_data.fireworks_api_key
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> OpenAI:
return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
async def completion(
self,
model_id: str,
@ -264,4 +270,15 @@ class FireworksInferenceAdapter(
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
client = self._get_openai_client()
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
response = client.embeddings.create(
model=model.provider_resource_id, input=contents, **kwargs
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -321,9 +321,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
response = await self.client.embed(
model=model.provider_resource_id, input=contents
)
embeddings = response["embeddings"]
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
# ollama does not have embedding models running. Check if the model is in list of available models.
if model.model_type == ModelType.embedding_model:
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. "
f"Available models: {', '.join(available_models)}"
)
return model
model = await self.register_helper.register_model(model)
models = await self.client.ps()
available_models = [m["model"] for m in models["models"]]

View file

@ -253,4 +253,9 @@ class TogetherInferenceAdapter(
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
r = self._get_client().embeddings.create(
model=model.provider_resource_id, input=contents
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -203,4 +203,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
response = self.client.embeddings.create(
model=model.provider_resource_id, input=contents, **kwargs
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)