Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-30 11:50:14 +00:00)
chore: Add OpenAI compatibility for Ollama embeddings
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent: fef670b024
commit: 15042c6c31
2 changed files with 39 additions and 3 deletions
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
-    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -46,6 +45,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -386,7 +388,42 @@ class OllamaInferenceAdapter(
         dimensions: int | None = None,
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        model_obj = await self._get_model(model)
+        if model_obj.model_type != ModelType.embedding:
+            raise ValueError(f"Model {model} is not an embedding model")
+
+        params = {
+            "model": model_obj.provider_resource_id,
+            "input": input,
+        }
+
+        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
+        if encoding_format is not None:
+            params["encoding_format"] = encoding_format
+        if dimensions is not None:
+            params["dimensions"] = str(dimensions)
+        if user is not None:
+            params["user"] = user
+
+        response = await self.openai_client.embeddings.create(**params)
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response.usage.prompt_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=response.model,
+            usage=usage,
+        )
 
     async def openai_completion(
         self,
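With this change, embeddings served by Ollama become reachable through the stack's OpenAI-compatible surface. A minimal usage sketch follows; it is not part of this commit, and the base URL path, API key handling, and model id are all assumptions that may differ per deployment.

# Hypothetical usage sketch, not part of this commit. Assumes a running
# Llama Stack server backed by Ollama with an embedding model registered;
# the base_url mount point and the "all-minilm" model id are assumptions.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed OpenAI-compat path
    api_key="none",  # local deployments typically do not check the key
)

response = client.embeddings.create(
    model="all-minilm",
    input=["Hello, world!", "How are you today?"],
)

for item in response.data:
    print(item.index, len(item.embedding))  # one vector per input string
print(response.usage.prompt_tokens, response.usage.total_tokens)

Note that encoding_format, dimensions, and user are still forwarded when callers supply them, even though (per the code comment) Ollama does not currently support those parameters; this keeps the adapter's behavior aligned with the OpenAI embeddings API signature.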
In the integration tests, remote::ollama is removed from the providers that skip OpenAI embeddings coverage, since the adapter now implements it:

@@ -51,7 +51,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         "remote::runpod",
         "remote::sambanova",
         "remote::tgi",
-        "remote::ollama",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 