Merge branch 'main' into add-nvidia-inference-adapter

This commit is contained in:
Matthew Farrellee 2024-11-17 15:47:13 -05:00
commit c24f882f31
6 changed files with 51 additions and 22 deletions

View file

@@ -7,7 +7,7 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
@@ -37,6 +37,8 @@ class ModelInput(CommonModelFields):
provider_id: Optional[str] = None provider_id: Optional[str] = None
provider_model_id: Optional[str] = None provider_model_id: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@runtime_checkable @runtime_checkable
class Models(Protocol): class Models(Protocol):

View file

@@ -48,7 +48,10 @@ class StackRun(Subcommand):
from llama_stack.distribution.build import ImageType from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import (
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
if not args.config: if not args.config:
@@ -68,6 +71,14 @@ class StackRun(Subcommand):
BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml" BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml"
) )
if not config_file.exists() and not args.config.endswith(".yaml"):
# check if it's a build config saved to ~/.llama dir
config_file = Path(
DISTRIBS_BASE_DIR
/ f"llamastack-{args.config}"
/ f"{args.config}-run.yaml"
)
if not config_file.exists(): if not config_file.exists():
self.parser.error( self.parser.error(
f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file" f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"

View file

@@ -369,12 +369,16 @@ def main(
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( with warnings.catch_warnings():
create_dynamic_typed_route( warnings.filterwarnings(
impl_method, "ignore", category=UserWarning, module="pydantic._internal._fields"
endpoint.method, )
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
) )
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"]) cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints: for endpoint in endpoints:

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import io
import json import json
import logging import logging
@@ -45,7 +46,12 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = {} self.chunk_by_index = {}
self.kvstore = kvstore self.kvstore = kvstore
self.bank_id = bank_id self.bank_id = bank_id
self.initialize()
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.kvstore: if not self.kvstore:
@@ -62,19 +68,20 @@ class FaissIndex(EmbeddingIndex):
for k, v in data["chunk_by_index"].items() for k, v in data["chunk_by_index"].items()
} }
index_bytes = base64.b64decode(data["faiss_index"]) buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
self.index = faiss.deserialize_index(index_bytes) self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
async def _save_index(self): async def _save_index(self):
if not self.kvstore or not self.bank_id: if not self.kvstore or not self.bank_id:
return return
index_bytes = faiss.serialize_index(self.index) np_index = faiss.serialize_index(self.index)
buffer = io.BytesIO()
np.savetxt(buffer, np_index)
data = { data = {
"id_by_index": self.id_by_index, "id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
"faiss_index": base64.b64encode(index_bytes).decode(), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"faiss_index:v1::{self.bank_id}"
@@ -132,7 +139,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks: for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data) bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex( index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore) bank=bank,
index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
),
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
@@ -158,7 +168,9 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
# Store in cache # Store in cache
index = BankWithIndex( index = BankWithIndex(
bank=memory_bank, bank=memory_bank,
index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore), index=await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
),
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
@@ -178,7 +190,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
) -> None: ) -> None:
index = self.cache.get(bank_id) index = self.cache.get(bank_id)
if index is None: if index is None:
raise ValueError(f"Bank {bank_id} not found") raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
await index.insert_documents(documents) await index.insert_documents(documents)

View file

@@ -157,7 +157,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[ pip_packages=[
"openai", "openai",
], ],
module="llama_stack.providers.adapters.inference.nvidia", module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
), ),
), ),

View file

@@ -84,7 +84,7 @@ _MODEL_ALIASES = [
] ]
class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference): class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: NVIDIAConfig) -> None: def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models # TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
@@ -117,7 +117,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
def completion( def completion(
self, self,
model: str, model_id: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
@@ -128,14 +128,14 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings( async def embeddings(
self, self,
model: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() raise NotImplementedError()
async def chat_completion( async def chat_completion(
self, self,
model: str, model_id: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
@@ -156,7 +156,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
request = convert_chat_completion_request( request = convert_chat_completion_request(
request=ChatCompletionRequest( request=ChatCompletionRequest(
model=self.get_provider_model_id(model), model=self.get_provider_model_id(model_id),
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools, tools=tools,