implement embedding generation in supported inference providers

This commit is contained in:
Dinesh Yeduguru 2024-12-09 12:48:56 -08:00
parent b896be2311
commit e167e9eb93
16 changed files with 383 additions and 29 deletions

View file

@ -321,9 +321,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
model = await self.model_store.get_model(model_id)
response = await self.client.embed(
model=model.provider_resource_id, input=contents
)
embeddings = response["embeddings"]
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
# ollama does not have embedding models running. Check if the model is in list of available models.
if model.model_type == ModelType.embedding_model:
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. "
f"Available models: {', '.join(available_models)}"
)
return model
model = await self.register_helper.register_model(model)
models = await self.client.ps()
available_models = [m["model"] for m in models["models"]]