Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Implement embeddings for ollama

commit 302555b11a (parent 0cec86453b)
4 changed files with 65 additions and 1 deletion
@@ -195,6 +195,15 @@ class BatchChatCompletionResponse(BaseModel):
     batch: List[ChatCompletionResponse]


+@json_schema_type
+class EmbeddingRequest(BaseModel):
+    model: str
+    contents: List[InterleavedTextMedia]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    truncate: Optional[bool] = True
+    logprobs: Optional[LogProbConfig] = None
+
+
 @json_schema_type
 class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
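To see what the new schema means in practice, here is a minimal sketch that builds an EmbeddingRequest and reads an EmbeddingsResponse. The import path and the model identifier are assumptions for illustration, not something this commit specifies; plain strings are valid InterleavedTextMedia items (the new test below passes a bare string list).

from llama_stack.apis.inference import (  # assumed import path
    EmbeddingRequest,
    EmbeddingsResponse,
)

# Only model and contents are required; sampling_params, truncate and
# logprobs fall back to the defaults declared on the model.
request = EmbeddingRequest(
    model="Llama3.1-8B-Instruct",  # hypothetical model identifier
    contents=["Roses are red", "Violets are blue"],
)

# The response carries one embedding vector per input item.
response = EmbeddingsResponse(embeddings=[[0.12, -0.03, 0.88], [0.05, 0.41, -0.22]])
assert len(response.embeddings) == len(request.contents)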
@@ -241,4 +250,7 @@ class Inference(Protocol):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse: ...
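With the extended protocol signature, a caller only needs the model and contents; the new parameters default sensibly. A hedged sketch of calling embeddings() against any Inference implementation (the inference handle and the model name below are placeholders):

# `inference` stands for any object implementing the Inference protocol.
async def embed_texts(inference, texts: list[str]) -> list[list[float]]:
    response = await inference.embeddings(
        model="Llama3.1-8B-Instruct",  # placeholder model identifier
        contents=texts,
    )
    return response.embeddings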
@@ -128,10 +128,16 @@ class InferenceRouter(Inference):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
         return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
         )
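The router adds no embedding logic of its own: it resolves whichever provider is registered for the requested model and forwards every argument unchanged. A toy stand-in (not the real routing table class) illustrating that dispatch shape:

# Hypothetical illustration only; llama-stack supplies the real routing table.
class ToyRoutingTable:
    def __init__(self, providers: dict):
        # Maps a model identifier to the provider implementation serving it.
        self.providers = providers

    def get_provider_impl(self, model: str):
        return self.providers[model]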
@@ -234,5 +234,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         contents: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        truncate: Optional[bool] = True,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        request = EmbeddingRequest(
+            model=model,
+            contents=contents,
+            sampling_params=sampling_params,
+            truncate=truncate,
+            logprobs=logprobs,
+        )
+        return await self._embeddings(request)
+
+    async def _embeddings(self, request: EmbeddingRequest) -> EmbeddingsResponse:
+        params = self._get_params_for_embeddings(request)
+        r = await self.client.embed(**params)
+        return EmbeddingsResponse(embeddings=r["embeddings"])
+
+    def _get_params_for_embeddings(self, request: EmbeddingRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "input": request.contents,
+            "options": sampling_options,
+            "truncate": request.truncate,
+        }
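The adapter ultimately hands the params dict to the Ollama Python client's embed() call with the keys laid out by _get_params_for_embeddings (model, input, options, truncate). A standalone sketch of that call, assuming the ollama package is installed, a server is running locally, and using a hypothetical model tag:

import asyncio

from ollama import AsyncClient  # assumes the `ollama` Python package


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    # Same keys the adapter builds: model, input, options, truncate.
    r = await client.embed(
        model="llama3.1:8b-instruct-fp16",  # hypothetical Ollama model tag
        input=["Roses are red"],
        options={"temperature": 0.0},
        truncate=True,
    )
    print(len(r["embeddings"]), "embedding vector(s) returned")


asyncio.run(main())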
@@ -170,6 +170,27 @@ async def test_completion(inference_settings):
     assert last.stop_reason == StopReason.out_of_tokens


+@pytest.mark.asyncio
+async def test_embed(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_spec__.provider_type not in ("remote::ollama",):
+        pytest.skip("Other inference providers don't support embeddings() yet")
+
+    response = await inference_impl.embeddings(
+        contents=["Roses are red"],
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) > 0
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]