mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 10:42:36 +00:00
Merge branch 'main' into add-nvidia-inference-adapter
This commit is contained in:
commit
c24f882f31
6 changed files with 51 additions and 22 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
|
@ -45,7 +46,12 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.chunk_by_index = {}
|
||||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
self.initialize()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
await instance.initialize()
|
||||
return instance
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.kvstore:
|
||||
|
|
@ -62,19 +68,20 @@ class FaissIndex(EmbeddingIndex):
|
|||
for k, v in data["chunk_by_index"].items()
|
||||
}
|
||||
|
||||
index_bytes = base64.b64decode(data["faiss_index"])
|
||||
self.index = faiss.deserialize_index(index_bytes)
|
||||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
|
||||
|
||||
async def _save_index(self):
|
||||
if not self.kvstore or not self.bank_id:
|
||||
return
|
||||
|
||||
index_bytes = faiss.serialize_index(self.index)
|
||||
|
||||
np_index = faiss.serialize_index(self.index)
|
||||
buffer = io.BytesIO()
|
||||
np.savetxt(buffer, np_index)
|
||||
data = {
|
||||
"id_by_index": self.id_by_index,
|
||||
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
||||
"faiss_index": base64.b64encode(index_bytes).decode(),
|
||||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||
}
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
|
|
@ -132,7 +139,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore)
|
||||
bank=bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
||||
),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
|
|
@ -158,7 +168,9 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
||||
),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
|
|
@ -178,7 +190,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
) -> None:
|
||||
index = self.cache.get(bank_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.adapters.inference.nvidia",
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
),
|
||||
),
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ _MODEL_ALIASES = [
|
|||
]
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
|
||||
|
|
@ -117,7 +117,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
|
|
@ -128,14 +128,14 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
|
|
@ -156,7 +156,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model),
|
||||
model=self.get_provider_model_id(model_id),
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue