nuke updates

Dinesh Yeduguru 2024-11-14 17:05:09 -08:00
parent 690e525a36
commit aa93eeb2b7
15 changed files with 15 additions and 429 deletions

View file

@@ -45,8 +45,6 @@ class Api(Enum):
 class ModelsProtocolPrivate(Protocol):
     async def register_model(self, model: Model) -> None: ...

-    async def update_model(self, model: Model) -> None: ...
-
     async def unregister_model(self, model_id: str) -> None: ...
@@ -61,8 +59,6 @@ class MemoryBanksProtocolPrivate(Protocol):
     async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None: ...
-

 class DatasetsProtocolPrivate(Protocol):
     async def register_dataset(self, dataset: Dataset) -> None: ...
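
With the update methods nuked, the private protocols are down to register/unregister pairs. A minimal sketch of what a conforming provider now looks like (the class name and dict-backed storage are illustrative, not from this commit; it assumes Model exposes an identifier field):

    from typing import Dict

    class InMemoryModelProvider:
        # Hypothetical provider satisfying the slimmed-down ModelsProtocolPrivate.
        def __init__(self) -> None:
            self._models: Dict[str, Model] = {}

        async def register_model(self, model: Model) -> None:
            # With no separate update path, re-registering an identifier
            # simply replaces the stored entry.
            self._models[model.identifier] = model

        async def unregister_model(self, model_id: str) -> None:
            self._models.pop(model_id, None)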
@@ -107,6 +103,7 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...


+# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -179,12 +176,10 @@ class RemoteProviderConfig(BaseModel):
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: Optional[AdapterSpec] = Field(
-        default=None,
+    adapter: AdapterSpec = Field(
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
+API responses, specify the adapter here.
 """,
     )
@@ -194,38 +189,21 @@ as being "Llama Stack compatible"
     @property
     def module(self) -> str:
-        if self.adapter:
-            return self.adapter.module
-        return "llama_stack.distribution.client"
+        return self.adapter.module

     @property
     def pip_packages(self) -> List[str]:
-        if self.adapter:
-            return self.adapter.pip_packages
-        return []
+        return self.adapter.pip_packages

     @property
     def provider_data_validator(self) -> Optional[str]:
-        if self.adapter:
-            return self.adapter.provider_data_validator
-        return None
+        return self.adapter.provider_data_validator


-def is_passthrough(spec: ProviderSpec) -> bool:
-    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
-
-
 # Can avoid this by using Pydantic computed_field
-def remote_provider_spec(
-    api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
-    config_class = (
-        adapter.config_class
-        if adapter and adapter.config_class
-        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
-    )
-    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
-
+def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
     return RemoteProviderSpec(
-        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+        api=api,
+        provider_type=f"remote::{adapter.adapter_type}",
+        config_class=adapter.config_class,
+        adapter=adapter,
     )
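
With `adapter` now required, the passthrough case that `is_passthrough` used to detect (adapter=None meaning the remote endpoint is already "Llama Stack compatible") is gone, and every remote spec is derived from an AdapterSpec. A sketch of the new call shape; the module path, pip package, and config class below are made-up placeholder values:

    spec = remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="ollama",
            module="llama_stack.providers.remote.inference.ollama",
            pip_packages=["ollama"],
            config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
        ),
    )
    assert spec.provider_type == "remote::ollama"
    assert spec.module == "llama_stack.providers.remote.inference.ollama"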

View file

@@ -71,9 +71,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

View file

@@ -108,9 +108,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         return VLLMSamplingParams(**kwargs)

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

View file

@@ -48,10 +48,9 @@ class FaissIndex(EmbeddingIndex):
         self.initialize()

     async def initialize(self) -> None:
-        if not self.kvstore or not self.bank_id:
+        if not self.kvstore:
             return

-        # Load existing index data from kvstore
         index_key = f"faiss_index:v1::{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)
@@ -63,7 +62,6 @@ class FaissIndex(EmbeddingIndex):
                 for k, v in data["chunk_by_index"].items()
             }

-            # Load FAISS index
             index_bytes = base64.b64decode(data["faiss_index"])
             self.index = faiss.deserialize_index(index_bytes)
@@ -71,17 +69,14 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore or not self.bank_id:
             return

-        # Serialize FAISS index
         index_bytes = faiss.serialize_index(self.index)

-        # Prepare data for storage
         data = {
             "id_by_index": self.id_by_index,
             "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
             "faiss_index": base64.b64encode(index_bytes).decode(),
         }

-        # Store in kvstore
         index_key = f"faiss_index:v1::{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))
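
As an aside on the save path kept above: kvstore values are JSON strings, while faiss.serialize_index returns raw bytes (a uint8 numpy array) that json.dumps cannot encode, hence the base64 hop. A standalone sketch of the round-trip with the kvstore swapped for a plain dict (assumes the faiss and numpy packages are installed):

    import base64
    import json

    import faiss
    import numpy as np

    index = faiss.IndexFlatL2(8)
    index.add(np.random.rand(4, 8).astype(np.float32))

    # serialize_index yields a uint8 array; base64 makes it JSON-safe.
    blob = base64.b64encode(faiss.serialize_index(index)).decode()
    store = {"faiss_index:v1::demo-bank": json.dumps({"faiss_index": blob})}

    # Reverse the trip: JSON -> base64 -> uint8 array -> index.
    data = json.loads(store["faiss_index:v1::demo-bank"])
    restored = faiss.deserialize_index(
        np.frombuffer(base64.b64decode(data["faiss_index"]), dtype=np.uint8)
    )
    assert restored.ntotal == index.ntotal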
@@ -175,15 +170,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         del self.cache[memory_bank_id]
         await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        # Not possible to update the index in place, so we delete and recreate
-        await self.cache[memory_bank.identifier].index.delete()
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            bank=memory_bank,
-            index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore),
-        )
-
     async def insert_documents(
         self,
         bank_id: str,

View file

@@ -93,9 +93,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

View file

@@ -69,9 +69,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

View file

@@ -58,9 +58,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def update_model(self, model: Model) -> None:
-        pass
-
     async def unregister_model(self, model_id: str) -> None:
         pass

View file

@@ -141,10 +141,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        await self.unregister_memory_bank(memory_bank.identifier)
-        await self.register_memory_bank(memory_bank)
-
     async def insert_documents(
         self,
         bank_id: str,

View file

@@ -184,10 +184,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

-    async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
-        await self.unregister_memory_bank(memory_bank.identifier)
-        await self.register_memory_bank(memory_bank)
-
     async def list_memory_banks(self) -> List[MemoryBank]:
         banks = load_models(self.cursor, VectorMemoryBank)
         for bank in banks: