mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 16:29:54 +00:00
Merge branch 'main' into add-nvidia-inference-adapter
This commit is contained in:
commit
5fbfb9d854
92 changed files with 2145 additions and 678 deletions
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
|
@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
descriptors = [m.descriptor() for m in permitted_models]
|
||||
repos = [m.huggingface_repo for m in permitted_models]
|
||||
if model not in (descriptors + repos):
|
||||
model_list = "\n\t".join(repos)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
|
|
@ -54,6 +56,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
|
|
@ -64,3 +67,16 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
|
||||
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
|
||||
quantization: QuantizationConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
|
||||
config["quantization"] = {
|
||||
"type": "fp8",
|
||||
}
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -37,19 +37,22 @@ class VLLMConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
|
||||
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
"enforce_eager": "${env.ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
|
||||
}
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
|
||||
descriptors = [m.descriptor() for m in permitted_models]
|
||||
repos = [m.huggingface_repo for m in permitted_models]
|
||||
if model not in (descriptors + repos):
|
||||
model_list = "\n\t".join(repos)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,11 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockConfig(BedrockBaseConfig):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -37,6 +37,18 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"endpoint_name": endpoint_name,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
|
|
@ -47,3 +59,15 @@ class InferenceAPIImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
repo: str = "${env.INFERENCE_MODEL}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"huggingface_repo": repo,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
|
|
@ -159,8 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
||||
collection = await self.client.get_collection(bank_id)
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
|
|
@ -213,8 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import pytest
|
|||
#
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
|
||||
# -m "meta_reference"
|
||||
# --env TOGETHER_API_KEY=<your_api_key>
|
||||
|
||||
|
||||
class TestModelRegistration:
|
||||
|
|
|
|||
|
|
@ -5,11 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
|
|
@ -57,3 +55,7 @@ class BedrockBaseConfig(BaseModel):
|
|||
default=3600,
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs):
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
|
|||
]
|
||||
|
||||
|
||||
def supported_inference_models() -> List[str]:
|
||||
def supported_inference_models() -> List[Model]:
|
||||
return [
|
||||
m.descriptor()
|
||||
m
|
||||
for m in all_registered_models()
|
||||
if (
|
||||
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
|
|
|
|||
|
|
@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
|
|||
cprint(f"Could not resolve model {llama_model}", color="red")
|
||||
return request.messages
|
||||
|
||||
if model.descriptor() not in supported_inference_models():
|
||||
allowed_models = supported_inference_models()
|
||||
descriptors = [m.descriptor() for m in allowed_models]
|
||||
if model.descriptor() not in descriptors:
|
||||
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
|
||||
return request.messages
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue