Merge branch 'main' into add-nvidia-inference-adapter

This commit is contained in:
Matthew Farrellee 2024-11-21 06:49:13 -05:00
commit 5fbfb9d854
92 changed files with 2145 additions and 678 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
@ -54,6 +56,7 @@ class MetaReferenceInferenceConfig(BaseModel):
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
@ -64,3 +67,16 @@ class MetaReferenceInferenceConfig(BaseModel):
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config

View file

@ -37,19 +37,22 @@ class VLLMConfig(BaseModel):
@classmethod
def sample_run_config(cls):
return {
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)

View file

@ -4,11 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
@json_schema_type
class BedrockConfig(BedrockBaseConfig):
pass

View file

@ -37,6 +37,18 @@ class InferenceEndpointImplConfig(BaseModel):
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"endpoint_name": endpoint_name,
"api_token": api_token,
}
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
@ -47,3 +59,15 @@ class InferenceAPIImplConfig(BaseModel):
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
repo: str = "${env.INFERENCE_MODEL}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"huggingface_repo": repo,
"api_token": api_token,
}

View file

@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
@ -159,8 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return index

View file

@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
@ -213,8 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index

View file

@ -11,7 +11,6 @@ import pytest
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
# -m "meta_reference"
# --env TOGETHER_API_KEY=<your_api_key>
class TestModelRegistration:

View file

@ -5,11 +5,9 @@
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class BedrockBaseConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
default=None,
@ -57,3 +55,7 @@ class BedrockBaseConfig(BaseModel):
default=3600,
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)
@classmethod
def sample_run_config(cls, **kwargs):
return {}

View file

@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
]
def supported_inference_models() -> List[str]:
def supported_inference_models() -> List[Model]:
return [
m.descriptor()
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}

View file

@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
cprint(f"Could not resolve model {llama_model}", color="red")
return request.messages
if model.descriptor() not in supported_inference_models():
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
return request.messages