chore: make the Together client async and run the FAISS index search asynchronously

sarthakdeshpande 2025-03-09 17:35:27 +05:30
parent b5a505aa8d
commit 501467846d
2 changed files with 21 additions and 32 deletions

@@ -8,6 +8,7 @@ import base64
 import io
 import json
 import logging
+import asyncio
 
 from typing import Any, Dict, List, Optional
 
 import faiss
@@ -99,7 +100,9 @@ class FaissIndex(EmbeddingIndex):
         await self._save_index()
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
+        distances, indices = await asyncio.to_thread(
+            self.index.search, embedding.reshape(1, -1).astype(np.float32), k
+        )
         chunks = []
         scores = []
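For context on the FAISS change: index.search is a blocking call into native code, so it cannot be awaited directly; asyncio.to_thread runs it on a worker thread and keeps the event loop free. A minimal self-contained sketch of the same pattern (the toy dimension, data, and k are made up for illustration):

import asyncio

import faiss
import numpy as np


async def main() -> None:
    dim = 8  # toy dimension, for illustration only
    index = faiss.IndexFlatL2(dim)
    index.add(np.random.rand(100, dim).astype(np.float32))

    query = np.random.rand(1, dim).astype(np.float32)
    # Offload the blocking native search to a worker thread so the
    # event loop can keep serving other coroutines in the meantime.
    distances, indices = await asyncio.to_thread(index.search, query, 3)
    print(distances, indices)


asyncio.run(main())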

@ -1,12 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from together import Together from together import AsyncTogether
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@@ -91,7 +85,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_client(self) -> Together:
+    async def _get_client(self) -> AsyncTogether:
         together_api_key = None
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
@@ -103,23 +97,18 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
                 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
             )
             together_api_key = provider_data.together_api_key
-        return Together(api_key=together_api_key)
+        return AsyncTogether(api_key=together_api_key)
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        client = await self._get_client()
+        r = await client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        client = await self._get_client()
+        stream = await client.completions.create(**params)
 
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream):
             yield chunk
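With AsyncTogether, an awaited streaming create returns an async iterator directly, which is what lets the _to_async_generator shim above be deleted. A hedged sketch of the pattern outside the adapter (the model name and prompt are placeholders, stream=True mirrors what _get_params would set for a streaming request, and TOGETHER_API_KEY is assumed to be set in the environment):

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumed: picks up TOGETHER_API_KEY from the environment

    # Awaiting a streaming create yields an async iterator,
    # so no sync-to-async wrapper function is needed.
    stream = await client.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model
        prompt="Say hello.",
        stream=True,
    )
    async for chunk in stream:
        print(chunk)  # each chunk is one parsed streaming event


asyncio.run(main())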
@@ -184,25 +173,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        client = await self._get_client()
         if "messages" in params:
-            r = self._get_client().chat.completions.create(**params)
+            r = await client.chat.completions.create(**params)
         else:
-            r = self._get_client().completions.create(**params)
+            r = await client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        client = await self._get_client()
 
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            if "messages" in params:
-                s = self._get_client().chat.completions.create(**params)
-            else:
-                s = self._get_client().completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        if "messages" in params:
+            stream = await client.chat.completions.create(**params)
+        else:
+            stream = await client.completions.create(**params)
+
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
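The chat path keeps the same endpoint routing as before: params rendered as messages go to the chat endpoint, raw prompts to completions. A sketch of that branching with the async client (the model and message are placeholders, and the response fields follow the SDK's OpenAI-compatible shape):

import asyncio

from together import AsyncTogether


async def chat_once(params: dict) -> str:
    client = AsyncTogether()  # assumed: TOGETHER_API_KEY in the environment
    # Route to the chat endpoint when the request was rendered as
    # messages; otherwise fall back to the raw-prompt completions endpoint.
    if "messages" in params:
        r = await client.chat.completions.create(**params)
        return r.choices[0].message.content
    r = await client.completions.create(**params)
    return r.choices[0].text


params = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder
    "messages": [{"role": "user", "content": "Say hello."}],
}
print(asyncio.run(chat_once(params)))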
@@ -240,7 +225,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         assert all(not content_has_media(content) for content in contents), (
             "Together does not support media for embeddings"
         )
-        r = self._get_client().embeddings.create(
+        client = await self._get_client()
+        r = await client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
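The embeddings path follows the same await-the-client, then await-the-call shape. A minimal sketch (the model id is a placeholder; the response is assumed to expose the SDK's OpenAI-compatible data/embedding fields):

import asyncio

from together import AsyncTogether


async def embed(texts: list[str]) -> list[list[float]]:
    client = AsyncTogether()  # assumed: TOGETHER_API_KEY in the environment
    r = await client.embeddings.create(
        model="togethercomputer/m2-bert-80M-8k-retrieval",  # placeholder model id
        input=texts,
    )
    return [item.embedding for item in r.data]


vectors = asyncio.run(embed(["hello", "world"]))
print(len(vectors), len(vectors[0]))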