mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 09:09:48 +00:00

commit 43262df033
Merge branch 'main' into add-nvidia-inference-adapter

399 changed files with 17826 additions and 10490 deletions
@@ -271,7 +271,7 @@ class Session(BaseModel):
     turns: List[Turn]
     started_at: datetime

-    memory_bank: Optional[MemoryBankDef] = None
+    memory_bank: Optional[MemoryBank] = None


 class AgentConfigCommon(BaseModel):
@@ -21,7 +21,7 @@ class PaginatedRowsResult(BaseModel):


 class DatasetStore(Protocol):
-    def get_dataset(self, identifier: str) -> DatasetDefWithProvider: ...
+    def get_dataset(self, dataset_id: str) -> Dataset: ...


 @runtime_checkable
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Literal, Optional, Protocol

 from llama_models.llama3.api.datatypes import URL
@@ -13,16 +13,11 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.resource import Resource, ResourceType


-@json_schema_type
-class DatasetDef(BaseModel):
-    identifier: str = Field(
-        description="A unique name for the dataset",
-    )
-    dataset_schema: Dict[str, ParamType] = Field(
-        description="The schema definition for this dataset",
-    )
+class CommonDatasetFields(BaseModel):
+    dataset_schema: Dict[str, ParamType]
     url: URL
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
@@ -31,24 +26,41 @@ class DatasetDef(BaseModel):


 @json_schema_type
-class DatasetDefWithProvider(DatasetDef):
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
+class Dataset(CommonDatasetFields, Resource):
+    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
+
+    @property
+    def dataset_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_dataset_id(self) -> str:
+        return self.provider_resource_id
+
+
+class DatasetInput(CommonDatasetFields, BaseModel):
+    dataset_id: str
+    provider_id: Optional[str] = None
+    provider_dataset_id: Optional[str] = None


 class Datasets(Protocol):
     @webmethod(route="/datasets/register", method="POST")
     async def register_dataset(
         self,
-        dataset_def: DatasetDefWithProvider,
+        dataset_id: str,
+        dataset_schema: Dict[str, ParamType],
+        url: URL,
+        provider_dataset_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...

     @webmethod(route="/datasets/get", method="GET")
     async def get_dataset(
         self,
-        dataset_identifier: str,
-    ) -> Optional[DatasetDefWithProvider]: ...
+        dataset_id: str,
+    ) -> Optional[Dataset]: ...

     @webmethod(route="/datasets/list", method="GET")
-    async def list_datasets(self) -> List[DatasetDefWithProvider]: ...
+    async def list_datasets(self) -> List[Dataset]: ...
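For context, a minimal sketch of calling the new flat register_dataset signature. The datasets_impl object, dataset values, and provider name are illustrative assumptions, not part of this diff; StringType is assumed from llama_stack.apis.common.type_system.

# Hypothetical caller against an implementation of the Datasets protocol.
await datasets_impl.register_dataset(
    dataset_id="my_eval_set",
    dataset_schema={
        "input_query": StringType(),
        "expected_answer": StringType(),
    },
    url=URL(uri="https://example.com/eval_set.jsonl"),
    provider_id="localfs",  # optional
)
dataset = await datasets_impl.get_dataset(dataset_id="my_eval_set")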
@@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.eval_tasks import *  # noqa: F403


 @json_schema_type
@@ -35,36 +36,65 @@ EvalCandidate = Annotated[
 ]


+@json_schema_type
+class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
+    eval_candidate: EvalCandidate
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
+
+
+@json_schema_type
+class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
+    eval_candidate: EvalCandidate
+    scoring_params: Dict[str, ScoringFnParams] = Field(
+        description="Map between scoring function id and parameters for each scoring function you want to run",
+        default_factory=dict,
+    )
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
+    # we could optionally add any specific dataset config here
+
+
+EvalTaskConfig = Annotated[
+    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
+]
+
+
 @json_schema_type
 class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]

     # each key in the dict is a scoring function name
     scores: Dict[str, ScoringResult]


 class Eval(Protocol):
-    @webmethod(route="/eval/evaluate_batch", method="POST")
-    async def evaluate_batch(
+    @webmethod(route="/eval/run_eval", method="POST")
+    async def run_eval(
         self,
-        dataset_id: str,
-        candidate: EvalCandidate,
-        scoring_functions: List[str],
+        task_id: str,
+        task_config: EvalTaskConfig,
     ) -> Job: ...

-    @webmethod(route="/eval/evaluate", method="POST")
-    async def evaluate(
+    @webmethod(route="/eval/evaluate_rows", method="POST")
+    async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
-        candidate: EvalCandidate,
         scoring_functions: List[str],
+        task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, job_id: str) -> None: ...
+    async def job_cancel(self, task_id: str, job_id: str) -> None: ...

     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> EvaluateResponse: ...
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
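For context, a sketch of driving the reworked Eval protocol end to end. The eval_impl object, task id, and model name are assumptions; ModelCandidate and SamplingParams come from the surrounding APIs, and job_id is assumed to be the Job field.

job = await eval_impl.run_eval(
    task_id="my_eval_task",
    task_config=AppEvalTaskConfig(
        eval_candidate=ModelCandidate(
            model="Llama3.1-8B-Instruct",
            sampling_params=SamplingParams(),
        ),
    ),
)
# Job handling is now scoped by task_id as well as job_id.
status = await eval_impl.job_status(task_id="my_eval_task", job_id=job.job_id)
result = await eval_impl.job_result(task_id="my_eval_task", job_id=job.job_id)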
llama_stack/apis/eval_tasks/__init__.py (new file, +7)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .eval_tasks import *  # noqa: F401 F403
llama_stack/apis/eval_tasks/eval_tasks.py (new file, +60)
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+
+from llama_models.schema_utils import json_schema_type, webmethod
+
+from pydantic import BaseModel, Field
+
+from llama_stack.apis.resource import Resource, ResourceType
+
+
+class CommonEvalTaskFields(BaseModel):
+    dataset_id: str
+    scoring_functions: List[str]
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Metadata for this evaluation task",
+    )
+
+
+@json_schema_type
+class EvalTask(CommonEvalTaskFields, Resource):
+    type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
+
+    @property
+    def eval_task_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_eval_task_id(self) -> str:
+        return self.provider_resource_id
+
+
+class EvalTaskInput(CommonEvalTaskFields, BaseModel):
+    eval_task_id: str
+    provider_id: Optional[str] = None
+    provider_eval_task_id: Optional[str] = None
+
+
+@runtime_checkable
+class EvalTasks(Protocol):
+    @webmethod(route="/eval_tasks/list", method="GET")
+    async def list_eval_tasks(self) -> List[EvalTask]: ...
+
+    @webmethod(route="/eval_tasks/get", method="GET")
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
+
+    @webmethod(route="/eval_tasks/register", method="POST")
+    async def register_eval_task(
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None: ...
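For context, a sketch of registering a task against this new API. The eval_tasks_impl object and identifiers are assumptions, and the scoring function id is a placeholder.

await eval_tasks_impl.register_eval_task(
    eval_task_id="my_eval_task",
    dataset_id="my_eval_set",
    scoring_functions=["basic::equality"],  # placeholder scoring fn id
)
tasks = await eval_tasks_impl.list_eval_tasks()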
@@ -216,7 +216,7 @@ class EmbeddingsResponse(BaseModel):


 class ModelStore(Protocol):
-    def get_model(self, identifier: str) -> ModelDef: ...
+    def get_model(self, identifier: str) -> Model: ...


 @runtime_checkable
@@ -226,7 +226,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -237,7 +237,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
@@ -254,6 +254,6 @@ class Inference(Protocol):
     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
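All three inference entrypoints now take model_id (the Llama Stack identifier) rather than model. A sketch of an updated call site; the inference_impl object and model name are assumptions.

response = await inference_impl.chat_completion(
    model_id="Llama3.1-8B-Instruct",  # keyword was `model=` before this change
    messages=[UserMessage(content="Hello!")],
    sampling_params=SamplingParams(),
)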
@@ -75,14 +75,22 @@ class MemoryClient(Memory):
 async def run_main(host: str, port: int, stream: bool):
     banks_client = MemoryBanksClient(f"http://{host}:{port}")

-    bank = VectorMemoryBankDef(
+    bank = VectorMemoryBank(
         identifier="test_bank",
         provider_id="",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
     )
-    await banks_client.register_memory_bank(bank)
+    await banks_client.register_memory_bank(
+        bank.identifier,
+        VectorMemoryBankParams(
+            embedding_model="all-MiniLM-L6-v2",
+            chunk_size_in_tokens=512,
+            overlap_size_in_tokens=64,
+        ),
+        provider_resource_id=bank.identifier,
+    )

     retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
     assert retrieved_bank is not None
@@ -39,7 +39,7 @@ class QueryDocumentsResponse(BaseModel):


 class MemoryBankStore(Protocol):
-    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
+    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...


 @runtime_checkable
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import json

 from typing import Any, Dict, List, Optional
@@ -26,13 +25,13 @@ def deserialize_memory_bank_def(
         raise ValueError("Memory bank type not specified")
     type = j["type"]
     if type == MemoryBankType.vector.value:
-        return VectorMemoryBankDef(**j)
+        return VectorMemoryBank(**j)
     elif type == MemoryBankType.keyvalue.value:
-        return KeyValueMemoryBankDef(**j)
+        return KeyValueMemoryBank(**j)
     elif type == MemoryBankType.keyword.value:
-        return KeywordMemoryBankDef(**j)
+        return KeywordMemoryBank(**j)
     elif type == MemoryBankType.graph.value:
-        return GraphMemoryBankDef(**j)
+        return GraphMemoryBank(**j)
     else:
         raise ValueError(f"Unknown memory bank type: {type}")
@@ -47,7 +46,7 @@ class MemoryBanksClient(MemoryBanks):
     async def shutdown(self) -> None:
         pass

-    async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
+    async def list_memory_banks(self) -> List[MemoryBank]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/memory_banks/list",
@@ -57,13 +56,20 @@ class MemoryBanksClient(MemoryBanks):
         return [deserialize_memory_bank_def(x) for x in response.json()]

     async def register_memory_bank(
-        self, memory_bank: MemoryBankDefWithProvider
+        self,
+        memory_bank_id: str,
+        params: BankParams,
+        provider_resource_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
     ) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/memory_banks/register",
                 json={
-                    "memory_bank": json.loads(memory_bank.json()),
+                    "memory_bank_id": memory_bank_id,
+                    "provider_resource_id": provider_resource_id,
+                    "provider_id": provider_id,
+                    "params": params.dict(),
                 },
                 headers={"Content-Type": "application/json"},
             )
@@ -71,13 +77,13 @@ class MemoryBanksClient(MemoryBanks):

     async def get_memory_bank(
         self,
-        identifier: str,
-    ) -> Optional[MemoryBankDefWithProvider]:
+        memory_bank_id: str,
+    ) -> Optional[MemoryBank]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/memory_banks/get",
                 params={
-                    "identifier": identifier,
+                    "memory_bank_id": memory_bank_id,
                 },
                 headers={"Content-Type": "application/json"},
             )
@@ -94,12 +100,12 @@ async def run_main(host: str, port: int, stream: bool):

     # register memory bank for the first time
     response = await client.register_memory_bank(
-        VectorMemoryBankDef(
-            identifier="test_bank2",
+        memory_bank_id="test_bank2",
+        params=VectorMemoryBankParams(
             embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
-        )
+        ),
     )
     cprint(f"register_memory_bank response={response}", "blue")
@@ -5,11 +5,21 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    Annotated,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
+
+from llama_stack.apis.resource import Resource, ResourceType


 @json_schema_type
@@ -20,59 +30,120 @@ class MemoryBankType(Enum):
     graph = "graph"


-class CommonDef(BaseModel):
-    identifier: str
-    # Hack: move this out later
-    provider_id: str = ""
-
-
+# define params for each type of memory bank, this leads to a tagged union
+# accepted as input from the API or from the config.
 @json_schema_type
-class VectorMemoryBankDef(CommonDef):
-    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+class VectorMemoryBankParams(BaseModel):
+    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
     overlap_size_in_tokens: Optional[int] = None


 @json_schema_type
-class KeyValueMemoryBankDef(CommonDef):
-    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
+class KeyValueMemoryBankParams(BaseModel):
+    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
+        MemoryBankType.keyvalue.value
+    )


 @json_schema_type
-class KeywordMemoryBankDef(CommonDef):
-    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
+class KeywordMemoryBankParams(BaseModel):
+    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
+        MemoryBankType.keyword.value
+    )


 @json_schema_type
-class GraphMemoryBankDef(CommonDef):
-    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+class GraphMemoryBankParams(BaseModel):
+    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value


-MemoryBankDef = Annotated[
+BankParams = Annotated[
     Union[
-        VectorMemoryBankDef,
-        KeyValueMemoryBankDef,
-        KeywordMemoryBankDef,
-        GraphMemoryBankDef,
+        VectorMemoryBankParams,
+        KeyValueMemoryBankParams,
+        KeywordMemoryBankParams,
+        GraphMemoryBankParams,
     ],
-    Field(discriminator="type"),
+    Field(discriminator="memory_bank_type"),
 ]

-MemoryBankDefWithProvider = MemoryBankDef
+
+# Some common functionality for memory banks.
+class MemoryBankResourceMixin(Resource):
+    type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
+
+    @property
+    def memory_bank_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_memory_bank_id(self) -> str:
+        return self.provider_resource_id
+
+
+@json_schema_type
+class VectorMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+    embedding_model: str
+    chunk_size_in_tokens: int
+    overlap_size_in_tokens: Optional[int] = None
+
+
+@json_schema_type
+class KeyValueMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
+        MemoryBankType.keyvalue.value
+    )
+
+
+# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
+@json_schema_type
+class KeywordMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.keyword.value] = (
+        MemoryBankType.keyword.value
+    )
+
+
+@json_schema_type
+class GraphMemoryBank(MemoryBankResourceMixin):
+    memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+
+
+MemoryBank = Annotated[
+    Union[
+        VectorMemoryBank,
+        KeyValueMemoryBank,
+        KeywordMemoryBank,
+        GraphMemoryBank,
+    ],
+    Field(discriminator="memory_bank_type"),
+]
+
+
+class MemoryBankInput(BaseModel):
+    memory_bank_id: str
+    params: BankParams
+    provider_memory_bank_id: Optional[str] = None


 @runtime_checkable
 class MemoryBanks(Protocol):
     @webmethod(route="/memory_banks/list", method="GET")
-    async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ...
+    async def list_memory_banks(self) -> List[MemoryBank]: ...

     @webmethod(route="/memory_banks/get", method="GET")
-    async def get_memory_bank(
-        self, identifier: str
-    ) -> Optional[MemoryBankDefWithProvider]: ...
+    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...

     @webmethod(route="/memory_banks/register", method="POST")
     async def register_memory_bank(
-        self, memory_bank: MemoryBankDefWithProvider
-    ) -> None: ...
+        self,
+        memory_bank_id: str,
+        params: BankParams,
+        provider_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
+    ) -> MemoryBank: ...
+
+    @webmethod(route="/memory_banks/unregister", method="POST")
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
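For context, a sketch of the new registration flow. The params argument is the BankParams tagged union, dispatched on memory_bank_type; the memory_banks_impl object and provider name are assumptions.

bank = await memory_banks_impl.register_memory_bank(
    memory_bank_id="docs_bank",
    params=VectorMemoryBankParams(
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    ),
    provider_id="faiss",  # optional; provider name is an assumption
)
assert bank.memory_bank_id == "docs_bank"
await memory_banks_impl.unregister_memory_bank(memory_bank_id="docs_bank")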
@@ -26,16 +26,16 @@ class ModelsClient(Models):
     async def shutdown(self) -> None:
         pass

-    async def list_models(self) -> List[ModelDefWithProvider]:
+    async def list_models(self) -> List[Model]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/models/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return [ModelDefWithProvider(**x) for x in response.json()]
+            return [Model(**x) for x in response.json()]

-    async def register_model(self, model: ModelDefWithProvider) -> None:
+    async def register_model(self, model: Model) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/models/register",
@@ -46,7 +46,7 @@ class ModelsClient(Models):
             )
             response.raise_for_status()

-    async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
+    async def get_model(self, identifier: str) -> Optional[Model]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/models/get",
@@ -59,7 +59,16 @@ class ModelsClient(Models):
             j = response.json()
             if j is None:
                 return None
-            return ModelDefWithProvider(**j)
+            return Model(**j)
+
+    async def unregister_model(self, model_id: str) -> None:
+        async with httpx.AsyncClient() as client:
+            response = await client.delete(
+                f"{self.base_url}/models/delete",
+                params={"model_id": model_id},
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()


 async def run_main(host: str, port: int, stream: bool):
@@ -4,19 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

+from llama_stack.apis.resource import Resource, ResourceType

-class ModelDef(BaseModel):
-    identifier: str = Field(
-        description="A unique name for the model type",
-    )
-    llama_model: str = Field(
-        description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.",
-    )
+
+class CommonModelFields(BaseModel):
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
@@ -24,19 +20,40 @@ class ModelDef(BaseModel):


 @json_schema_type
-class ModelDefWithProvider(ModelDef):
-    provider_id: str = Field(
-        description="The provider ID for this model",
-    )
+class Model(CommonModelFields, Resource):
+    type: Literal[ResourceType.model.value] = ResourceType.model.value
+
+    @property
+    def model_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_model_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ModelInput(CommonModelFields):
+    model_id: str
+    provider_id: Optional[str] = None
+    provider_model_id: Optional[str] = None


 @runtime_checkable
 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> List[ModelDefWithProvider]: ...
+    async def list_models(self) -> List[Model]: ...

     @webmethod(route="/models/get", method="GET")
-    async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: ...
+    async def get_model(self, identifier: str) -> Optional[Model]: ...

     @webmethod(route="/models/register", method="POST")
-    async def register_model(self, model: ModelDefWithProvider) -> None: ...
+    async def register_model(
+        self,
+        model_id: str,
+        provider_model_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Model: ...
+
+    @webmethod(route="/models/unregister", method="POST")
+    async def unregister_model(self, model_id: str) -> None: ...
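For context, a sketch of the flattened registration call, which now returns the created Model resource. The models_impl object, provider name, and metadata values are assumptions.

model = await models_impl.register_model(
    model_id="Llama3.1-8B-Instruct",
    provider_id="meta-reference",  # optional
    metadata={"description": "8B instruct model"},  # optional
)
# model_id is a property alias over the underlying Resource identifier.
assert model.model_id == model.identifier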
llama_stack/apis/resource.py (new file, +39)
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class ResourceType(Enum):
+    model = "model"
+    shield = "shield"
+    memory_bank = "memory_bank"
+    dataset = "dataset"
+    scoring_function = "scoring_function"
+    eval_task = "eval_task"
+
+
+class Resource(BaseModel):
+    """Base class for all Llama Stack resources"""
+
+    identifier: str = Field(
+        description="Unique identifier for this resource in llama stack"
+    )
+
+    provider_resource_id: str = Field(
+        description="Unique identifier for this resource in the provider",
+        default=None,
+    )
+
+    provider_id: str = Field(description="ID of the provider that owns this resource")
+
+    type: ResourceType = Field(
+        description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)"
+    )
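Every registrable object in the stack now derives from this base. A sketch of the common fields a concrete resource carries once registered; the values here are illustrative, not from this commit.

res = Resource(
    identifier="my_model",                  # stack-level name
    provider_resource_id="provider/model",  # provider-level name
    provider_id="meta-reference",
    type=ResourceType.model,
)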
@@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:


 def encodable_dict(d: BaseModel):
-    return json.loads(d.json())
+    return json.loads(d.model_dump_json())


 class SafetyClient(Safety):
@@ -41,13 +41,13 @@ class SafetyClient(Safety):
         pass

     async def run_shield(
-        self, shield_type: str, messages: List[Message]
+        self, shield_id: str, messages: List[Message]
     ) -> RunShieldResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/safety/run_shield",
                 json=dict(
-                    shield_type=shield_type,
+                    shield_id=shield_id,
                     messages=[encodable_dict(m) for m in messages],
                 ),
                 headers={
@@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
     )
     cprint(f"User>{message.content}", "green")
     response = await client.run_shield(
-        shield_type="llama_guard",
+        shield_id="Llama-Guard-3-1B",
         messages=[message],
     )
     print(response)
@@ -91,7 +91,7 @@ async def run_main(host: str, port: int, image_path: str = None):
     ]:
         cprint(f"User>{message.content}", "green")
         response = await client.run_shield(
-            shield_type="llama_guard",
+            shield_id="llama_guard",
             messages=[message],
         )
         print(response)
@@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):


 class ShieldStore(Protocol):
-    def get_shield(self, identifier: str) -> ShieldDef: ...
+    async def get_shield(self, identifier: str) -> Shield: ...


 @runtime_checkable
@@ -48,5 +48,8 @@ class Safety(Protocol):

     @webmethod(route="/safety/run_shield")
     async def run_shield(
-        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
+        self,
+        shield_id: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
     ) -> RunShieldResponse: ...
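For context, a sketch of the renamed parameter in use. The safety_impl object and message are assumptions, and the violation handling assumes the RunShieldResponse shape used elsewhere in this API.

response = await safety_impl.run_shield(
    shield_id="Llama-Guard-3-1B",
    messages=[UserMessage(content="...")],
    params={},
)
if response.violation is not None:  # assumed response field
    print(response.violation.user_message)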
@@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):


 class ScoringFunctionStore(Protocol):
-    def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
+    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...


 @runtime_checkable
@@ -48,11 +48,13 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score")
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
@@ -4,71 +4,119 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
-
-
-@json_schema_type
-class Parameter(BaseModel):
-    name: str
-    type: ParamType
-    description: Optional[str] = None
+from llama_stack.apis.resource import Resource, ResourceType


 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
+@json_schema_type
+class ScoringFnParamsType(Enum):
+    llm_as_judge = "llm_as_judge"
+    regex_parser = "regex_parser"


-class LLMAsJudgeContext(BaseModel):
+@json_schema_type
+class LLMAsJudgeScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
+        ScoringFnParamsType.llm_as_judge.value
+    )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field(
-        description="Regex to extract the score from the judge response",
-        default=None,
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
     )


 @json_schema_type
-class ScoringFnDef(BaseModel):
-    identifier: str
+class RegexParserScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.regex_parser.value] = (
+        ScoringFnParamsType.regex_parser.value
+    )
+    parsing_regexes: Optional[List[str]] = Field(
+        description="Regex to extract the answer from generated response",
+        default_factory=list,
+    )
+
+
+ScoringFnParams = Annotated[
+    Union[
+        LLMAsJudgeScoringFnParams,
+        RegexParserScoringFnParams,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class CommonScoringFnFields(BaseModel):
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this definition",
     )
-    parameters: List[Parameter] = Field(
-        description="List of parameters for the deterministic function",
-        default_factory=list,
-    )
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    context: Optional[LLMAsJudgeContext] = None
-    # We can optionally add information here to support packaging of code, etc.
+    params: Optional[ScoringFnParams] = Field(
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
+        default=None,
+    )


 @json_schema_type
-class ScoringFnDefWithProvider(ScoringFnDef):
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
+class ScoringFn(CommonScoringFnFields, Resource):
+    type: Literal[ResourceType.scoring_function.value] = (
+        ResourceType.scoring_function.value
+    )
+
+    @property
+    def scoring_fn_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_scoring_fn_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+    scoring_fn_id: str
+    provider_id: Optional[str] = None
+    provider_scoring_fn_id: Optional[str] = None


 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...

     @webmethod(route="/scoring_functions/get", method="GET")
-    async def get_scoring_function(
-        self, name: str
-    ) -> Optional[ScoringFnDefWithProvider]: ...
+    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...

     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function_def: ScoringFnDefWithProvider
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
     ) -> None: ...
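For context, a sketch of registering a judge-based scoring function with the new tagged ScoringFnParams union. The ids, judge model, template, and regex are placeholders, and NumberType is assumed from llama_stack.apis.common.type_system.

await scoring_functions_impl.register_scoring_function(
    scoring_fn_id="llm-as-judge::my_judge",
    description="Grade answers 1-5 with an LLM judge",
    return_type=NumberType(),
    params=LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-70B-Instruct",
        prompt_template="Rate the answer from 1 to 5: {answer}",
        judge_score_regexes=[r"Rating: (\d)"],
    ),
)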
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import json

 from typing import List, Optional
@@ -26,32 +25,41 @@ class ShieldsClient(Shields):
     async def shutdown(self) -> None:
         pass

-    async def list_shields(self) -> List[ShieldDefWithProvider]:
+    async def list_shields(self) -> List[Shield]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/shields/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return [ShieldDefWithProvider(**x) for x in response.json()]
+            return [Shield(**x) for x in response.json()]

-    async def register_shield(self, shield: ShieldDefWithProvider) -> None:
+    async def register_shield(
+        self,
+        shield_id: str,
+        provider_shield_id: Optional[str],
+        provider_id: Optional[str],
+        params: Optional[Dict[str, Any]],
+    ) -> None:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/shields/register",
                 json={
-                    "shield": json.loads(shield.json()),
+                    "shield_id": shield_id,
+                    "provider_shield_id": provider_shield_id,
+                    "provider_id": provider_id,
+                    "params": params,
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()

-    async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
+    async def get_shield(self, shield_id: str) -> Optional[Shield]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/shields/get",
                 params={
-                    "shield_type": shield_type,
+                    "shield_id": shield_id,
                 },
                 headers={"Content-Type": "application/json"},
             )
@@ -61,7 +69,7 @@ class ShieldsClient(Shields):
         if j is None:
             return None

-        return ShieldDefWithProvider(**j)
+        return Shield(**j)


 async def run_main(host: str, port: int, stream: bool):
@@ -4,48 +4,52 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
+
+from llama_stack.apis.resource import Resource, ResourceType
+
+
+class CommonShieldFields(BaseModel):
+    params: Optional[Dict[str, Any]] = None


 @json_schema_type
-class ShieldType(Enum):
-    generic_content_shield = "generic_content_shield"
-    llama_guard = "llama_guard"
-    code_scanner = "code_scanner"
-    prompt_guard = "prompt_guard"
+class Shield(CommonShieldFields, Resource):
+    """A safety shield resource that can be used to check content"""

+    type: Literal[ResourceType.shield.value] = ResourceType.shield.value

-class ShieldDef(BaseModel):
-    identifier: str = Field(
-        description="A unique identifier for the shield type",
-    )
-    type: str = Field(
-        description="The type of shield this is; the value is one of the ShieldType enum"
-    )
-    params: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="Any additional parameters needed for this shield",
-    )
+    @property
+    def shield_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_shield_id(self) -> str:
+        return self.provider_resource_id


-@json_schema_type
-class ShieldDefWithProvider(ShieldDef):
-    provider_id: str = Field(
-        description="The provider ID for this shield type",
-    )
+class ShieldInput(CommonShieldFields):
+    shield_id: str
+    provider_id: Optional[str] = None
+    provider_shield_id: Optional[str] = None


 @runtime_checkable
 class Shields(Protocol):
     @webmethod(route="/shields/list", method="GET")
-    async def list_shields(self) -> List[ShieldDefWithProvider]: ...
+    async def list_shields(self) -> List[Shield]: ...

     @webmethod(route="/shields/get", method="GET")
-    async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
+    async def get_shield(self, identifier: str) -> Optional[Shield]: ...

     @webmethod(route="/shields/register", method="POST")
-    async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
+    async def register_shield(
+        self,
+        shield_id: str,
+        provider_shield_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> Shield: ...
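For context, a sketch of the new registration call, which returns the created Shield. The shields_impl object and the provider-side names are assumptions.

shield = await shields_impl.register_shield(
    shield_id="llama_guard",
    provider_shield_id="Llama-Guard-3-1B",  # optional; provider-side name
    provider_id="llama-guard",              # optional
)
assert shield.shield_id == "llama_guard"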
@@ -9,15 +9,27 @@ import asyncio
 import json
 import os
 import shutil
 import time
+from dataclasses import dataclass
 from datetime import datetime
 from functools import partial
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 import httpx

 from llama_models.datatypes import Model
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel
+
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    DownloadColumn,
+    Progress,
+    TextColumn,
+    TimeRemainingColumn,
+    TransferSpeedColumn,
+)
 from termcolor import cprint

 from llama_stack.cli.subcommand import Subcommand
@@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
         required=False,
         help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
     )
+    parser.add_argument(
+        "--max-parallel",
+        type=int,
+        required=False,
+        default=3,
+        help="Maximum number of concurrent downloads",
+    )
     parser.add_argument(
         "--ignore-patterns",
         type=str,
@@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
     parser.set_defaults(func=partial(run_download_cmd, parser=parser))


+@dataclass
+class DownloadTask:
+    url: str
+    output_file: str
+    total_size: int = 0
+    downloaded_size: int = 0
+    task_id: Optional[int] = None
+    retries: int = 0
+    max_retries: int = 3
+
+
+class DownloadError(Exception):
+    pass
+
+
+class CustomTransferSpeedColumn(TransferSpeedColumn):
+    def render(self, task):
+        if task.finished:
+            return "-"
+        return super().render(task)
+
+
+class ParallelDownloader:
+    def __init__(
+        self,
+        max_concurrent_downloads: int = 3,
+        buffer_size: int = 1024 * 1024,
+        timeout: int = 30,
+    ):
+        self.max_concurrent_downloads = max_concurrent_downloads
+        self.buffer_size = buffer_size
+        self.timeout = timeout
+        self.console = Console()
+        self.progress = Progress(
+            TextColumn("[bold blue]{task.description}"),
+            BarColumn(bar_width=40),
+            "[progress.percentage]{task.percentage:>3.1f}%",
+            DownloadColumn(),
+            CustomTransferSpeedColumn(),
+            TimeRemainingColumn(),
+            console=self.console,
+            expand=True,
+        )
+        self.client_options = {
+            "timeout": httpx.Timeout(timeout),
+            "follow_redirects": True,
+        }
+
+    async def retry_with_exponential_backoff(
+        self, task: DownloadTask, func, *args, **kwargs
+    ):
+        last_exception = None
+        for attempt in range(task.max_retries):
+            try:
+                return await func(*args, **kwargs)
+            except Exception as e:
+                last_exception = e
+                if attempt < task.max_retries - 1:
+                    wait_time = min(30, 2**attempt)  # Cap at 30 seconds
+                    self.console.print(
+                        f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
+                        f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
+                    )
+                    await asyncio.sleep(wait_time)
+                    continue
+        raise last_exception
+
+    async def get_file_info(
+        self, client: httpx.AsyncClient, task: DownloadTask
+    ) -> None:
+        async def _get_info():
+            response = await client.head(
+                task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
+            )
+            response.raise_for_status()
+            return response
+
+        try:
+            response = await self.retry_with_exponential_backoff(task, _get_info)
+
+            task.url = str(response.url)
+            task.total_size = int(response.headers.get("Content-Length", 0))
+
+            if task.total_size == 0:
+                raise DownloadError(
+                    f"Unable to determine file size for {task.output_file}. "
+                    "The server might not support range requests."
+                )
+
+            # Update the progress bar's total size once we know it
+            if task.task_id is not None:
+                self.progress.update(task.task_id, total=task.total_size)
+
+        except httpx.HTTPError as e:
+            self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
+            raise
+
+    def verify_file_integrity(self, task: DownloadTask) -> bool:
+        if not os.path.exists(task.output_file):
+            return False
+        return os.path.getsize(task.output_file) == task.total_size
+
+    async def download_chunk(
+        self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
+    ) -> None:
+        async def _download_chunk():
+            headers = {"Range": f"bytes={start}-{end}"}
+            async with client.stream(
+                "GET", task.url, headers=headers, **self.client_options
+            ) as response:
+                response.raise_for_status()
+
+                with open(task.output_file, "ab") as file:
+                    file.seek(start)
+                    async for chunk in response.aiter_bytes(self.buffer_size):
+                        file.write(chunk)
+                        task.downloaded_size += len(chunk)
+                        self.progress.update(
+                            task.task_id,
+                            completed=task.downloaded_size,
+                        )
+
+        try:
+            await self.retry_with_exponential_backoff(task, _download_chunk)
+        except Exception as e:
+            raise DownloadError(
+                f"Failed to download chunk {start}-{end} after "
+                f"{task.max_retries} attempts: {str(e)}"
+            ) from e
+
+    async def prepare_download(self, task: DownloadTask) -> None:
+        output_dir = os.path.dirname(task.output_file)
+        os.makedirs(output_dir, exist_ok=True)
+
+        if os.path.exists(task.output_file):
+            task.downloaded_size = os.path.getsize(task.output_file)
+
+    async def download_file(self, task: DownloadTask) -> None:
+        try:
+            async with httpx.AsyncClient(**self.client_options) as client:
+                await self.get_file_info(client, task)
+
+                # Check if file is already downloaded
+                if os.path.exists(task.output_file):
+                    if self.verify_file_integrity(task):
+                        self.console.print(
+                            f"[green]Already downloaded {task.output_file}[/green]"
+                        )
+                        self.progress.update(task.task_id, completed=task.total_size)
+                        return
+
+                await self.prepare_download(task)
+
+                try:
+                    # Split the remaining download into chunks
+                    chunk_size = 27_000_000_000  # Cloudfront max chunk size
+                    chunks = []
+
+                    current_pos = task.downloaded_size
+                    while current_pos < task.total_size:
+                        chunk_end = min(
+                            current_pos + chunk_size - 1, task.total_size - 1
+                        )
+                        chunks.append((current_pos, chunk_end))
+                        current_pos = chunk_end + 1
+
+                    # Download chunks in sequence
+                    for chunk_start, chunk_end in chunks:
+                        await self.download_chunk(client, task, chunk_start, chunk_end)
+
+                except Exception as e:
+                    raise DownloadError(f"Download failed: {str(e)}") from e
+
+        except Exception as e:
+            self.progress.update(
+                task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
+            )
+            raise DownloadError(
+                f"Download failed for {task.output_file}: {str(e)}"
+            ) from e
+
+    def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
+        try:
+            total_remaining_size = sum(
+                task.total_size - task.downloaded_size for task in tasks
+            )
+            dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
+            free_space = shutil.disk_usage(dir_path).free
+
+            # Add 10% buffer for safety
+            required_space = int(total_remaining_size * 1.1)
+
+            if free_space < required_space:
+                self.console.print(
+                    f"[red]Not enough disk space. Required: {required_space // (1024*1024)} MB, "
+                    f"Available: {free_space // (1024*1024)} MB[/red]"
+                )
+                return False
+            return True
+
+        except Exception as e:
+            raise DownloadError(f"Failed to check disk space: {str(e)}") from e
+
+    async def download_all(self, tasks: List[DownloadTask]) -> None:
+        if not tasks:
+            raise ValueError("No download tasks provided")
+
+        if not self.has_disk_space(tasks):
+            raise DownloadError("Insufficient disk space for downloads")
+
+        failed_tasks = []
+
+        with self.progress:
+            for task in tasks:
+                desc = f"Downloading {Path(task.output_file).name}"
+                task.task_id = self.progress.add_task(
+                    desc, total=task.total_size, completed=task.downloaded_size
+                )
+
+            semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
+
+            async def download_with_semaphore(task: DownloadTask):
+                async with semaphore:
+                    try:
+                        await self.download_file(task)
+                    except Exception as e:
+                        failed_tasks.append((task, str(e)))
+
+            await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
+
+        if failed_tasks:
+            self.console.print("\n[red]Some downloads failed:[/red]")
+            for task, error in failed_tasks:
+                self.console.print(
+                    f"[red]- {Path(task.output_file).name}: {error}[/red]"
+                )
+            raise DownloadError(f"{len(failed_tasks)} downloads failed")
+
+
 def _hf_download(
     model: "Model",
     hf_token: str,
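For context, a sketch of driving the new downloader directly; the URLs and paths are placeholders. With the default max_retries=3, retry_with_exponential_backoff sleeps min(30, 2**attempt) seconds between attempts, i.e. 1s after the first failure and 2s after the second, before the final attempt raises.

tasks = [
    DownloadTask(url="https://example.com/a.bin", output_file="/tmp/a.bin"),
    DownloadTask(url="https://example.com/b.bin", output_file="/tmp/b.bin"),
]
downloader = ParallelDownloader(max_concurrent_downloads=2)
asyncio.run(downloader.download_all(tasks))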
@@ -120,63 +378,37 @@ def _hf_download(
     print(f"\nSuccessfully downloaded model to {true_output_dir}")


-def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
+def _meta_download(
+    model: "Model",
+    meta_url: str,
+    info: "LlamaDownloadInfo",
+    max_concurrent_downloads: int,
+):
     from llama_stack.distribution.utils.model_utils import model_local_dir

     output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)

-    # I believe we can use some concurrency here if needed but not sure it is worth it
+    # Create download tasks for each file
+    tasks = []
     for f in info.files:
         output_file = str(output_dir / f)
         url = meta_url.replace("*", f"{info.folder}/{f}")
         total_size = info.pth_size if "consolidated" in f else 0
-        cprint(f"Downloading `{f}`...", "white")
-        downloader = ResumableDownloader(url, output_file, total_size)
-        asyncio.run(downloader.download())
+        tasks.append(
+            DownloadTask(
+                url=url, output_file=output_file, total_size=total_size, max_retries=3
+            )
+        )
+
+    # Initialize and run parallel downloader
+    downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
+    asyncio.run(downloader.download_all(tasks))

     print(f"\nSuccessfully downloaded model to {output_dir}")
     cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")


-def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
-    from llama_models.sku_list import llama_meta_net_info, resolve_model
-
-    from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
-
-    if args.manifest_file:
-        _download_from_manifest(args.manifest_file)
-        return
-
-    if args.model_id is None:
-        parser.error("Please provide a model id")
-        return
-
-    # Check if model_id is a comma-separated list
-    model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
-
-    prompt_guard = prompt_guard_model_sku()
-    for model_id in model_ids:
-        if model_id == prompt_guard.model_id:
-            model = prompt_guard
-            info = prompt_guard_download_info()
-        else:
-            model = resolve_model(model_id)
-            if model is None:
-                parser.error(f"Model {model_id} not found")
-                continue
-            info = llama_meta_net_info(model)
-
-        if args.source == "huggingface":
-            _hf_download(model, args.hf_token, args.ignore_patterns, parser)
-        else:
-            meta_url = args.meta_url or input(
-                f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
-            )
-            assert "llamameta.net" in meta_url
-            _meta_download(model, meta_url, info)
-
-
 class ModelEntry(BaseModel):
     model_id: str
     files: Dict[str, str]
@@ -190,7 +422,7 @@ class Manifest(BaseModel):
     expires_on: datetime


-def _download_from_manifest(manifest_file: str):
+def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
     from llama_stack.distribution.utils.model_utils import model_local_dir

     with open(manifest_file, "r") as f:
@@ -200,143 +432,88 @@ def _download_from_manifest(manifest_file: str):
     if datetime.now() > manifest.expires_on:
         raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

+    console = Console()
     for entry in manifest.models:
-        print(f"Downloading model {entry.model_id}...")
+        console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
         output_dir = Path(model_local_dir(entry.model_id))
         os.makedirs(output_dir, exist_ok=True)

         if any(output_dir.iterdir()):
-            cprint(f"Output directory {output_dir} is not empty.", "red")
+            console.print(
+                f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
+            )

             while True:
                 resp = input(
                     "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                 )
-                if resp.lower() == "restart" or resp.lower() == "r":
+                if resp.lower() in ["restart", "r"]:
                     shutil.rmtree(output_dir)
                     os.makedirs(output_dir, exist_ok=True)
                     break
-                elif resp.lower() == "continue" or resp.lower() == "c":
-                    print("Continuing download...")
+                elif resp.lower() in ["continue", "c"]:
+                    console.print("[blue]Continuing download...[/blue]")
                     break
                 else:
-                    cprint("Invalid response. Please try again.", "red")
+                    console.print("[red]Invalid response. Please try again.[/red]")

-        for fname, url in entry.files.items():
-            output_file = str(output_dir / fname)
-            downloader = ResumableDownloader(url, output_file)
-            asyncio.run(downloader.download())
+        # Create download tasks for all files in the manifest
+        tasks = [
+            DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
+            for fname, url in entry.files.items()
+        ]
+
+        # Initialize and run parallel downloader
+        downloader = ParallelDownloader(
+            max_concurrent_downloads=max_concurrent_downloads
+        )
+        asyncio.run(downloader.download_all(tasks))


-class ResumableDownloader:
-    def __init__(
-        self,
-        url: str,
-        output_file: str,
-        total_size: int = 0,
-        buffer_size: int = 32 * 1024,
-    ):
-        self.url = url
-        self.output_file = output_file
-        self.buffer_size = buffer_size
-        self.total_size = total_size
-        self.downloaded_size = 0
-        self.start_size = 0
-        self.start_time = 0
-
-    async def get_file_info(self, client: httpx.AsyncClient) -> None:
-        if self.total_size > 0:
-            return
-
-        # Force disable compression when trying to retrieve file size
-        response = await client.head(
-            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
-        )
-        response.raise_for_status()
-        self.url = str(response.url)  # Update URL in case of redirects
-        self.total_size = int(response.headers.get("Content-Length", 0))
-        if self.total_size == 0:
-            raise ValueError(
-                "Unable to determine file size. The server might not support range requests."
-            )
-
-    async def download(self) -> None:
-        self.start_time = time.time()
-        async with httpx.AsyncClient(follow_redirects=True) as client:
-            await self.get_file_info(client)
-
-            if os.path.exists(self.output_file):
-                self.downloaded_size = os.path.getsize(self.output_file)
-                self.start_size = self.downloaded_size
-                if self.downloaded_size >= self.total_size:
-                    print(f"Already downloaded `{self.output_file}`, skipping...")
-                    return
-
-            additional_size = self.total_size - self.downloaded_size
-            if not self.has_disk_space(additional_size):
-                M = 1024 * 1024  # noqa
-                print(
-                    f"Not enough disk space to download `{self.output_file}`. "
-                    f"Required: {(additional_size // M):.2f} MB"
-                )
-                raise ValueError(
-                    f"Not enough disk space to download `{self.output_file}`"
-                )
-
-            while True:
-                if self.downloaded_size >= self.total_size:
-                    break
-
-                # Cloudfront has a max-size limit
-                max_chunk_size = 27_000_000_000
-                request_size = min(
-                    self.total_size - self.downloaded_size, max_chunk_size
-                )
-                headers = {
-                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
-                }
-                print(f"Downloading `{self.output_file}`....{headers}")
-                try:
-                    async with client.stream(
-                        "GET", self.url, headers=headers
-                    ) as response:
-                        response.raise_for_status()
-                        with open(self.output_file, "ab") as file:
-                            async for chunk in response.aiter_bytes(self.buffer_size):
-                                file.write(chunk)
-                                self.downloaded_size += len(chunk)
-                                self.print_progress()
-                except httpx.HTTPError as e:
-                    print(f"\nDownload interrupted: {e}")
-                    print("You can resume the download by running the script again.")
-                except Exception as e:
-                    print(f"\nAn error occurred: {e}")
-
-            print(f"\nFinished downloading `{self.output_file}`....")
-
-    def print_progress(self) -> None:
-        percent = (self.downloaded_size / self.total_size) * 100
-        bar_length = 50
-        filled_length = int(bar_length * self.downloaded_size // self.total_size)
-        bar = "█" * filled_length + "-" * (bar_length - filled_length)
-
-        elapsed_time = time.time() - self.start_time
-        M = 1024 * 1024  # noqa
-
-        speed = (
-            (self.downloaded_size - self.start_size) / (elapsed_time * M)
-            if elapsed_time > 0
-            else 0
-        )
-        print(
-            f"\rProgress: |{bar}| {percent:.2f}% "
-            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
-            f"Speed: {speed:.2f} MiB/s",
-            end="",
-            flush=True,
-        )
-
-    def has_disk_space(self, file_size: int) -> bool:
-        dir_path = os.path.dirname(os.path.abspath(self.output_file))
-        free_space = shutil.disk_usage(dir_path).free
-        return free_space > file_size
+def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    """Main download command handler"""
+    try:
+        if args.manifest_file:
+            _download_from_manifest(args.manifest_file, args.max_parallel)
+            return
+
+        if args.model_id is None:
+            parser.error("Please provide a model id")
+            return
+
+        # Handle comma-separated model IDs
+        model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
+
+        from llama_models.sku_list import llama_meta_net_info, resolve_model
+
+        from .model.safety_models import (
+            prompt_guard_download_info,
+            prompt_guard_model_sku,
+        )
+
+        prompt_guard = prompt_guard_model_sku()
+        for model_id in model_ids:
+            if model_id == prompt_guard.model_id:
+                model = prompt_guard
+                info = prompt_guard_download_info()
+            else:
+                model = resolve_model(model_id)
+                if model is None:
+                    parser.error(f"Model {model_id} not found")
|
||||
continue
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url or input(
|
||||
f"Please provide the signed URL for model {model_id} you received via email "
|
||||
f"after visiting https://www.llama.com/llama-downloads/ "
|
||||
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
if "llamameta.net" not in meta_url:
|
||||
parser.error("Invalid Meta URL provided")
|
||||
_meta_download(model, meta_url, info, args.max_parallel)
|
||||
|
||||
except Exception as e:
|
||||
parser.error(f"Download failed: {str(e)}")
|
||||
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import argparse

from .download import Download
from .model import ModelParser
from .stack import StackParser
from .verify_download import VerifyDownload


class LlamaCLIParser:

@@ -27,9 +28,10 @@ class LlamaCLIParser:
        subparsers = self.parser.add_subparsers(title="subcommands")

        # Add sub-commands
        Download.create(subparsers)
        ModelParser.create(subparsers)
        StackParser.create(subparsers)
        Download.create(subparsers)
        VerifyDownload.create(subparsers)

    def parse_args(self) -> argparse.Namespace:
        return self.parser.parse_args()
@@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload

from llama_stack.cli.subcommand import Subcommand


@@ -32,3 +33,4 @@ class ModelParser(Subcommand):
        ModelList.create(subparsers)
        ModelPromptFormat.create(subparsers)
        ModelDescribe.create(subparsers)
        ModelVerifyDownload.create(subparsers)
24 llama_stack/cli/model/verify_download.py Normal file
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack.cli.subcommand import Subcommand


class ModelVerifyDownload(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "verify-download",
            prog="llama model verify-download",
            description="Verify the downloaded checkpoints' checksums",
            formatter_class=argparse.RawTextHelpFormatter,
        )

        from llama_stack.cli.verify_download import setup_verify_download_parser

        setup_verify_download_parser(self.parser)
@@ -12,6 +12,10 @@ import os
from functools import lru_cache
from pathlib import Path

from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type


TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
@@ -176,6 +180,66 @@ class StackBuild(Subcommand):
            return
        self._run_stack_build_command_from_build_config(build_config)

    def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None:
        """
        Generate a run.yaml template file for user to edit from a build.yaml file
        """
        import json

        import yaml
        from termcolor import cprint

        from llama_stack.distribution.build import ImageType

        apis = list(build_config.distribution_spec.providers.keys())
        run_config = StackRunConfig(
            built_at=datetime.now(),
            docker_image=(
                build_config.name
                if build_config.image_type == ImageType.docker.value
                else None
            ),
            image_name=build_config.name,
            conda_env=(
                build_config.name
                if build_config.image_type == ImageType.conda.value
                else None
            ),
            apis=apis,
            providers={},
        )
        # build providers dict
        provider_registry = get_provider_registry()
        for api in apis:
            run_config.providers[api] = []
            provider_types = build_config.distribution_spec.providers[api]
            if isinstance(provider_types, str):
                provider_types = [provider_types]

            for i, provider_type in enumerate(provider_types):
                p_spec = Provider(
                    provider_id=f"{provider_type}-{i}",
                    provider_type=provider_type,
                    config={},
                )
                config_type = instantiate_class_type(
                    provider_registry[Api(api)][provider_type].config_class
                )
                p_spec.config = config_type()
                run_config.providers[api].append(p_spec)

        os.makedirs(build_dir, exist_ok=True)
        run_config_file = build_dir / f"{build_config.name}-run.yaml"

        with open(run_config_file, "w") as f:
            to_write = json.loads(run_config.model_dump_json())
            f.write(yaml.dump(to_write, sort_keys=False))

        cprint(
            f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`",
            color="green",
        )

    def _run_stack_build_command_from_build_config(
        self, build_config: BuildConfig
    ) -> None:
@@ -183,48 +247,24 @@ class StackBuild(Subcommand):
        import os

        import yaml
        from termcolor import cprint

        from llama_stack.distribution.build import build_image, ImageType
        from llama_stack.distribution.build import build_image
        from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
        from llama_stack.distribution.utils.serialize import EnumEncoder

        # save build.yaml spec for building same distribution again
        if build_config.image_type == ImageType.docker.value:
            # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
            llama_stack_path = Path(
                os.path.abspath(__file__)
            ).parent.parent.parent.parent
            build_dir = llama_stack_path / "tmp/configs/"
        else:
            build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"

        build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
        os.makedirs(build_dir, exist_ok=True)
        build_file_path = build_dir / f"{build_config.name}-build.yaml"

        with open(build_file_path, "w") as f:
            to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
            to_write = json.loads(build_config.model_dump_json())
            f.write(yaml.dump(to_write, sort_keys=False))

        return_code = build_image(build_config, build_file_path)
        if return_code != 0:
            return

        configure_name = (
            build_config.name
            if build_config.image_type == "conda"
            else (f"llamastack-{build_config.name}")
        )
        if build_config.image_type == "conda":
            cprint(
                f"You can now run `llama stack configure {configure_name}`",
                color="green",
            )
        else:
            cprint(
                f"You can now edit your run.yaml file and run `docker run -it -p 5000:5000 {build_config.name}`. See full command in llama-stack/distributions/",
                color="green",
            )
        self._generate_run_config(build_config, build_dir)

    def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
        import json
@@ -7,8 +7,6 @@
import argparse

from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.datatypes import *  # noqa: F403


class StackConfigure(Subcommand):
@@ -39,123 +37,10 @@ class StackConfigure(Subcommand):
        )

    def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
        import json
        import os
        import subprocess
        from pathlib import Path

        import pkg_resources

        import yaml
        from termcolor import cprint

        from llama_stack.distribution.build import ImageType
        from llama_stack.distribution.utils.exec import run_with_pty

        docker_image = None

        build_config_file = Path(args.config)
        if build_config_file.exists():
            with open(build_config_file, "r") as f:
                build_config = BuildConfig(**yaml.safe_load(f))
            self._configure_llama_distribution(build_config, args.output_dir)
            return

        conda_dir = (
            Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
        )
        output = subprocess.check_output(["bash", "-c", "conda info --json"])
        conda_envs = json.loads(output.decode("utf-8"))["envs"]

        for x in conda_envs:
            if x.endswith(f"/llamastack-{args.config}"):
                conda_dir = Path(x)
                break

        build_config_file = Path(conda_dir) / f"{args.config}-build.yaml"
        if build_config_file.exists():
            with open(build_config_file, "r") as f:
                build_config = BuildConfig(**yaml.safe_load(f))

            cprint(f"Using {build_config_file}...", "green")
            self._configure_llama_distribution(build_config, args.output_dir)
            return

        docker_image = args.config
        builds_dir = BUILDS_BASE_DIR / ImageType.docker.value
        if args.output_dir:
            builds_dir = Path(output_dir)
        os.makedirs(builds_dir, exist_ok=True)

        script = pkg_resources.resource_filename(
            "llama_stack", "distribution/configure_container.sh"
        )
        script_args = [script, docker_image, str(builds_dir)]

        return_code = run_with_pty(script_args)
        if return_code != 0:
            self.parser.error(
                f"Failed to configure container {docker_image} with return code {return_code}. Please run `llama stack build` first. "
            )

    def _configure_llama_distribution(
        self,
        build_config: BuildConfig,
        output_dir: Optional[str] = None,
    ):
        import json
        import os
        from pathlib import Path

        import yaml
        from termcolor import cprint

        from llama_stack.distribution.configure import (
            configure_api_providers,
            parse_and_maybe_upgrade_config,
        )
        from llama_stack.distribution.utils.serialize import EnumEncoder

        builds_dir = BUILDS_BASE_DIR / build_config.image_type
        if output_dir:
            builds_dir = Path(output_dir)
        os.makedirs(builds_dir, exist_ok=True)
        image_name = build_config.name.replace("::", "-")
        run_config_file = builds_dir / f"{image_name}-run.yaml"

        if run_config_file.exists():
            cprint(
                f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
                "yellow",
                attrs=["bold"],
            )
            config_dict = yaml.safe_load(run_config_file.read_text())
            config = parse_and_maybe_upgrade_config(config_dict)
        else:
            config = StackRunConfig(
                built_at=datetime.now(),
                image_name=image_name,
                apis=list(build_config.distribution_spec.providers.keys()),
                providers={},
            )

        config = configure_api_providers(config, build_config.distribution_spec)

        config.docker_image = (
            image_name if build_config.image_type == "docker" else None
        )
        config.conda_env = image_name if build_config.image_type == "conda" else None

        with open(run_config_file, "w") as f:
            to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
            f.write(yaml.dump(to_write, sort_keys=False))

        cprint(
            f"> YAML configuration has been written to `{run_config_file}`.",
            color="blue",
        )

        cprint(
            f"You can now run `llama stack run {image_name} --port PORT`",
            color="green",
        self.parser.error(
            """
            DEPRECATED! llama stack configure has been deprecated.
            Please use llama stack run <path/to/run.yaml> instead.
            Please see example run.yaml in /distributions folder.
            """
        )
@@ -45,7 +45,6 @@ class StackRun(Subcommand):

        import pkg_resources
        import yaml
        from termcolor import cprint

        from llama_stack.distribution.build import ImageType
        from llama_stack.distribution.configure import parse_and_maybe_upgrade_config

@@ -71,14 +70,12 @@ class StackRun(Subcommand):

        if not config_file.exists():
            self.parser.error(
                f"File {str(config_file)} does not exist. Please run `llama stack build` and `llama stack configure <name>` to generate a run.yaml file"
                f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"
            )
            return

        cprint(f"Using config `{config_file}`", "green")
        with open(config_file, "r") as f:
            config_dict = yaml.safe_load(config_file.read_text())
            config = parse_and_maybe_upgrade_config(config_dict)
        config_dict = yaml.safe_load(config_file.read_text())
        config = parse_and_maybe_upgrade_config(config_dict)

        if config.docker_image:
            script = pkg_resources.resource_filename(
@@ -25,11 +25,11 @@ def up_to_date_config():
    providers:
      inference:
        - provider_id: provider1
          provider_type: meta-reference
          provider_type: inline::meta-reference
          config: {{}}
      safety:
        - provider_id: provider1
          provider_type: meta-reference
          provider_type: inline::meta-reference
          config:
            llama_guard_shield:
              model: Llama-Guard-3-1B

@@ -39,7 +39,7 @@ def up_to_date_config():
            enable_prompt_guard: false
      memory:
        - provider_id: provider1
          provider_type: meta-reference
          provider_type: inline::meta-reference
          config: {{}}
    """.format(
        version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()

@@ -61,13 +61,13 @@ def old_config():
            host: localhost
            port: 11434
          routing_key: Llama3.2-1B-Instruct
        - provider_type: meta-reference
        - provider_type: inline::meta-reference
          config:
            model: Llama3.1-8B-Instruct
          routing_key: Llama3.1-8B-Instruct
      safety:
        - routing_key: ["shield1", "shield2"]
          provider_type: meta-reference
          provider_type: inline::meta-reference
          config:
            llama_guard_shield:
              model: Llama-Guard-3-1B

@@ -77,7 +77,7 @@ def old_config():
            enable_prompt_guard: false
      memory:
        - routing_key: vector
          provider_type: meta-reference
          provider_type: inline::meta-reference
          config: {{}}
    api_providers:
      telemetry:
144 llama_stack/cli/verify_download.py Normal file
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional

from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn

from llama_stack.cli.subcommand import Subcommand


@dataclass
class VerificationResult:
    filename: str
    expected_hash: str
    actual_hash: Optional[str]
    exists: bool
    matches: bool


class VerifyDownload(Subcommand):
    """Llama cli for verifying downloaded model files"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "verify-download",
            prog="llama verify-download",
            description="Verify integrity of downloaded model files",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        setup_verify_download_parser(self.parser)


def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--model-id",
        required=True,
        help="Model ID to verify",
    )
    parser.set_defaults(func=partial(run_verify_cmd, parser=parser))


def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
    md5_hash = hashlib.md5()
    with open(filepath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def load_checksums(checklist_path: Path) -> Dict[str, str]:
    checksums = {}
    with open(checklist_path, "r") as f:
        for line in f:
            if line.strip():
                md5sum, filepath = line.strip().split(" ", 1)
                # Remove leading './' if present
                filepath = filepath.lstrip("./")
                checksums[filepath] = md5sum
    return checksums


def verify_files(
    model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
    results = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        for filepath, expected_hash in checksums.items():
            full_path = model_dir / filepath
            task_id = progress.add_task(f"Verifying {filepath}...", total=None)

            exists = full_path.exists()
            actual_hash = None
            matches = False

            if exists:
                actual_hash = calculate_md5(full_path)
                matches = actual_hash == expected_hash

            results.append(
                VerificationResult(
                    filename=filepath,
                    expected_hash=expected_hash,
                    actual_hash=actual_hash,
                    exists=exists,
                    matches=matches,
                )
            )

            progress.remove_task(task_id)

    return results


def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    from llama_stack.distribution.utils.model_utils import model_local_dir

    console = Console()
    model_dir = Path(model_local_dir(args.model_id))
    checklist_path = model_dir / "checklist.chk"

    if not model_dir.exists():
        parser.error(f"Model directory not found: {model_dir}")

    if not checklist_path.exists():
        parser.error(f"Checklist file not found: {checklist_path}")

    checksums = load_checksums(checklist_path)
    results = verify_files(model_dir, checksums, console)

    # Print results
    console.print("\nVerification Results:")

    all_good = True
    for result in results:
        if not result.exists:
            console.print(f"[red]❌ {result.filename}: File not found[/red]")
            all_good = False
        elif not result.matches:
            console.print(
                f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
                f"  Expected: {result.expected_hash}\n"
                f"  Got: {result.actual_hash}"
            )
            all_good = False
        else:
            console.print(f"[green]✓ {result.filename}: Verified[/green]")

    if all_good:
        console.print("\n[green]All files verified successfully![/green]")
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import List, Optional
from typing import List

import pkg_resources
from pydantic import BaseModel

@@ -25,6 +25,7 @@ from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
    "aiosqlite",
    "fastapi",
    "fire",
    "httpx",
@@ -37,28 +38,19 @@ class ImageType(Enum):
    conda = "conda"


class Dependencies(BaseModel):
    pip_packages: List[str]
    docker_image: Optional[str] = None


class ApiInput(BaseModel):
    api: Api
    provider: str


def build_image(build_config: BuildConfig, build_file_path: Path):
    package_deps = Dependencies(
        docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
        pip_packages=SERVER_DEPENDENCIES,
    )

    # extend package dependencies based on providers spec
def get_provider_dependencies(
    config_providers: Dict[str, List[Provider]]
) -> tuple[list[str], list[str]]:
    """Get normal and special dependencies from provider configuration."""
    all_providers = get_provider_registry()
    for (
        api_str,
        provider_or_providers,
    ) in build_config.distribution_spec.providers.items():
    deps = []

    for api_str, provider_or_providers in config_providers.items():
        providers_for_api = all_providers[Api(api_str)]

        providers = (

@@ -68,25 +60,50 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
        )

        for provider in providers:
            if provider not in providers_for_api:
            # Providers from BuildConfig and RunConfig are subtly different – not great
            provider_type = (
                provider if isinstance(provider, str) else provider.provider_type
            )

            if provider_type not in providers_for_api:
                raise ValueError(
                    f"Provider `{provider}` is not available for API `{api_str}`"
                )

            provider_spec = providers_for_api[provider]
            package_deps.pip_packages.extend(provider_spec.pip_packages)
            provider_spec = providers_for_api[provider_type]
            deps.extend(provider_spec.pip_packages)
            if provider_spec.docker_image:
                raise ValueError("A stack's dependencies cannot have a docker image")

    normal_deps = []
    special_deps = []
    deps = []
    for package in package_deps.pip_packages:
    for package in deps:
        if "--no-deps" in package or "--index-url" in package:
            special_deps.append(package)
        else:
            deps.append(package)
    deps = list(set(deps))
    special_deps = list(set(special_deps))
            normal_deps.append(package)

    return list(set(normal_deps)), list(set(special_deps))


def print_pip_install_help(providers: Dict[str, List[Provider]]):
    normal_deps, special_deps = get_provider_dependencies(providers)

    print(
        f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
    )
    for special_dep in special_deps:
        print(f"\tpip install {special_dep}")
    print()


def build_image(build_config: BuildConfig, build_file_path: Path):
    docker_image = build_config.distribution_spec.docker_image or "python:3.10-slim"

    normal_deps, special_deps = get_provider_dependencies(
        build_config.distribution_spec.providers
    )
    normal_deps += SERVER_DEPENDENCIES

    if build_config.image_type == ImageType.docker.value:
        script = pkg_resources.resource_filename(

@@ -95,10 +112,10 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
        args = [
            script,
            build_config.name,
            package_deps.docker_image,
            docker_image,
            str(build_file_path),
            str(BUILDS_BASE_DIR / ImageType.docker.value),
            " ".join(deps),
            " ".join(normal_deps),
        ]
    else:
        script = pkg_resources.resource_filename(

@@ -108,7 +125,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
            script,
            build_config.name,
            str(build_file_path),
            " ".join(deps),
            " ".join(normal_deps),
        ]

    if special_deps:
@@ -36,7 +36,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"

TEMP_DIR=$(mktemp -d)

@@ -65,6 +64,19 @@ RUN apt-get update && apt-get install -y \

EOF

# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$pip_dependencies" ]; then
  add_to_docker "RUN pip install --no-cache $pip_dependencies"
fi

if [ -n "$special_pip_deps" ]; then
  IFS='#' read -ra parts <<<"$special_pip_deps"
  for part in "${parts[@]}"; do
    add_to_docker "RUN pip install --no-cache $part"
  done
fi

stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"

@@ -79,7 +91,16 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
  # rebuild. This is just for development convenience.
  add_to_docker "RUN pip install --no-cache -e $stack_mount"
else
  add_to_docker "RUN pip install --no-cache llama-stack"
  if [ -n "$TEST_PYPI_VERSION" ]; then
    # these packages are damaged in test-pypi, so install them first
    add_to_docker "RUN pip install fastapi libcst"
    add_to_docker <<EOF
RUN pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
  llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
EOF
  else
    add_to_docker "RUN pip install --no-cache llama-stack"
  fi
fi

if [ -n "$LLAMA_MODELS_DIR" ]; then

@@ -95,16 +116,6 @@ RUN pip install --no-cache $models_mount
EOF
fi

if [ -n "$pip_dependencies" ]; then
  add_to_docker "RUN pip install --no-cache $pip_dependencies"
fi

if [ -n "$special_pip_deps" ]; then
  IFS='#' read -ra parts <<<"$special_pip_deps"
  for part in "${parts[@]}"; do
    add_to_docker "RUN pip install --no-cache $part"
  done
fi

add_to_docker <<EOF

@@ -115,8 +126,6 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]

EOF

add_to_docker "ADD tmp/configs/$(basename "$build_file_path") ./llamastack-build.yaml"

printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"

@@ -134,11 +143,32 @@ if command -v selinuxenabled &>/dev/null && selinuxenabled; then
  DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
fi

# Set version tag based on PyPI version
if [ -n "$TEST_PYPI_VERSION" ]; then
  version_tag="test-$TEST_PYPI_VERSION"
else
  URL="https://pypi.org/pypi/llama-stack/json"
  version_tag=$(curl -s $URL | jq -r '.info.version')
fi

# Add version tag to image name
image_tag="$image_name:$version_tag"

# Detect platform architecture
ARCH=$(uname -m)
if [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
  PLATFORM="--platform linux/arm64"
elif [ "$ARCH" = "x86_64" ]; then
  PLATFORM="--platform linux/amd64"
else
  echo "Unsupported architecture: $ARCH"
  exit 1
fi

set -x
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
$DOCKER_BINARY build $DOCKER_OPTS $PLATFORM -t $image_tag -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts

# clean up tmp/configs
rm -rf $REPO_CONFIGS_DIR
set +x

echo "Success!"
@@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}


async def get_client_impl(
    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
    client_class = create_api_client_class(protocol, additional_protocol)
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
    client_class = create_api_client_class(protocol)
    impl = client_class(config.url)
    await impl.initialize()
    return impl


def create_api_client_class(protocol, additional_protocol) -> Type:
def create_api_client_class(protocol) -> Type:
    if protocol in _CLIENT_CLASSES:
        return _CLIENT_CLASSES[protocol]

    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]

    class APIClient:
        def __init__(self, base_url: str):
            print(f"({protocol.__name__}) Connecting to {base_url}")

@@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
            self.routes = {}

            # Store routes for this protocol
            for p in protocols:
                for name, method in inspect.getmembers(p):
                    if hasattr(method, "__webmethod__"):
                        sig = inspect.signature(method)
                        self.routes[name] = (method.__webmethod__, sig)
            for name, method in inspect.getmembers(protocol):
                if hasattr(method, "__webmethod__"):
                    sig = inspect.signature(method)
                    self.routes[name] = (method.__webmethod__, sig)

        async def initialize(self):
            pass

@@ -83,6 +78,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
            j = response.json()
            if j is None:
                return None
            # print(f"({protocol.__name__}) Returning {j}, type {return_type}")
            return parse_obj_as(return_type, j)

        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:

@@ -102,14 +98,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
                    if line.startswith("data:"):
                        data = line[len("data: ") :]
                        try:
                            data = json.loads(data)
                            if "error" in data:
                                cprint(data, "red")
                                continue

                            yield parse_obj_as(return_type, json.loads(data))
                            yield parse_obj_as(return_type, data)
                        except Exception as e:
                            print(data)
                            print(f"Error with parsing or validation: {e}")
                            print(data)

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            webmethod, sig = self.routes[method_name]

@@ -141,27 +138,33 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
            else:
                data.update(convert(kwargs))

            return dict(
            ret = dict(
                method=webmethod.method or "POST",
                url=url,
                headers={"Content-Type": "application/json"},
                params=params,
                json=data,
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                timeout=30,
            )
            if params:
                ret["params"] = params
            if data:
                ret["json"] = data

            return ret

    # Add protocol methods to the wrapper
    for p in protocols:
        for name, method in inspect.getmembers(p):
            if hasattr(method, "__webmethod__"):
    for name, method in inspect.getmembers(protocol):
        if hasattr(method, "__webmethod__"):

                async def method_impl(self, *args, method_name=name, **kwargs):
                    return await self.__acall__(method_name, *args, **kwargs)
            async def method_impl(self, *args, method_name=name, **kwargs):
                return await self.__acall__(method_name, *args, **kwargs)

                method_impl.__name__ = name
                method_impl.__qualname__ = f"APIClient.{name}"
                method_impl.__signature__ = inspect.signature(method)
                setattr(APIClient, name, method_impl)
            method_impl.__name__ = name
            method_impl.__qualname__ = f"APIClient.{name}"
            method_impl.__signature__ = inspect.signature(method)
            setattr(APIClient, name, method_impl)

    # Name the class after the protocol
    APIClient.__name__ = f"{protocol.__name__}Client"
@@ -17,10 +17,13 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.utils.kvstore.config import KVStoreConfig

LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"

@@ -30,19 +33,25 @@ RoutingKey = Union[str, List[str]]


RoutableObject = Union[
    ModelDef,
    ShieldDef,
    MemoryBankDef,
    DatasetDef,
    ScoringFnDef,
    Model,
    Shield,
    MemoryBank,
    Dataset,
    ScoringFn,
    EvalTask,
]

RoutableObjectWithProvider = Union[
    ModelDefWithProvider,
    ShieldDefWithProvider,
    MemoryBankDefWithProvider,
    DatasetDefWithProvider,
    ScoringFnDefWithProvider,

RoutableObjectWithProvider = Annotated[
    Union[
        Model,
        Shield,
        MemoryBank,
        Dataset,
        ScoringFn,
        EvalTask,
    ],
    Field(discriminator="type"),
]

RoutedProtocol = Union[

@@ -51,6 +60,7 @@ RoutedProtocol = Union[
    Memory,
    DatasetIO,
    Scoring,
    Eval,
]


@@ -134,6 +144,20 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re
can be instantiated multiple times (with different configs) if necessary.
""",
    )
    metadata_store: Optional[KVStoreConfig] = Field(
        default=None,
        description="""
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
    )

    # registry of "resources" in the distribution
    models: List[ModelInput] = Field(default_factory=list)
    shields: List[ShieldInput] = Field(default_factory=list)
    memory_banks: List[MemoryBankInput] = Field(default_factory=list)
    datasets: List[DatasetInput] = Field(default_factory=list)
    scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
    eval_tasks: List[EvalTaskInput] = Field(default_factory=list)


class BuildConfig(BaseModel):
@@ -9,7 +9,7 @@ from typing import Dict, List

from pydantic import BaseModel

from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, ProviderSpec


def stack_apis() -> List[Api]:

@@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
            routing_table_api=Api.scoring_functions,
            router_api=Api.scoring,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.eval_tasks,
            router_api=Api.eval,
        ),
    ]


@@ -58,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
    for api in providable_apis():
        name = api.name.lower()
        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
        ret[api] = {
            "remote": remote_provider_spec(api),
            **{a.provider_type: a for a in module.available_providers()},
        }
        ret[api] = {a.provider_type: a for a in module.available_providers()}

    return ret
@@ -8,6 +8,8 @@ import inspect

from typing import Any, Dict, List, Set

from termcolor import cprint

from llama_stack.providers.datatypes import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403


@@ -15,6 +17,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory

@@ -25,10 +28,16 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type


class InvalidProviderError(Exception):
    pass


def api_protocol_map() -> Dict[Api, Any]:
    return {
        Api.agents: Agents,

@@ -45,16 +54,22 @@ def api_protocol_map() -> Dict[Api, Any]:
        Api.scoring: Scoring,
        Api.scoring_functions: ScoringFunctions,
        Api.eval: Eval,
        Api.eval_tasks: EvalTasks,
    }


def additional_protocols_map() -> Dict[Api, Any]:
    return {
        Api.inference: (ModelsProtocolPrivate, Models),
        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
        Api.safety: (ShieldsProtocolPrivate, Shields),
        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
        Api.scoring: (
            ScoringFunctionsProtocolPrivate,
            ScoringFunctions,
            Api.scoring_functions,
        ),
        Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
    }


@@ -63,9 +78,14 @@ class ProviderWithSpec(Provider):
    spec: ProviderSpec


ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]


# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
    run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
    """
    Does two things:

@@ -94,10 +114,20 @@ async def resolve_impls(
            )

            p = provider_registry[api][provider.provider_type]
            if p.deprecation_error:
                cprint(p.deprecation_error, "red", attrs=["bold"])
                raise InvalidProviderError(p.deprecation_error)

            elif p.deprecation_warning:
                cprint(
                    f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
                    "yellow",
                    attrs=["bold"],
                )
            p.deps__ = [a.value for a in p.api_dependencies]
            spec = ProviderWithSpec(
                spec=p,
                **(provider.dict()),
                **(provider.model_dump()),
            )
            specs[provider.provider_id] = spec


@@ -189,6 +219,7 @@ async def resolve_impls(
            provider,
            deps,
            inner_impls,
            dist_registry,
        )
        # TODO: ugh slightly redesign this shady looking code
        if "inner-" in api_str:

@@ -237,6 +268,7 @@ async def instantiate_provider(
    provider: ProviderWithSpec,
    deps: Dict[str, Any],
    inner_impls: Dict[str, Any],
    dist_registry: DistributionRegistry,
):
    protocols = api_protocol_map()
    additional_protocols = additional_protocols_map()

@@ -249,17 +281,8 @@ async def instantiate_provider(
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)

        if provider_spec.adapter:
            method = "get_adapter_impl"
            args = [config, deps]
        else:
            method = "get_client_impl"
            protocol = protocols[provider_spec.api]
            if provider_spec.api in additional_protocols:
                _, additional_protocol = additional_protocols[provider_spec.api]
            else:
                additional_protocol = None
            args = [protocol, additional_protocol, config, deps]
        method = "get_adapter_impl"
        args = [config, deps]

    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"

@@ -270,7 +293,7 @@ async def instantiate_provider(
        method = "get_routing_table_impl"

        config = None
        args = [provider_spec.api, inner_impls, deps]
        args = [provider_spec.api, inner_impls, deps, dist_registry]
    else:
        method = "get_provider_impl"


@@ -289,7 +312,7 @@ async def instantiate_provider(
        not isinstance(provider_spec, AutoRoutedProviderSpec)
        and provider_spec.api in additional_protocols
    ):
        additional_api, _ = additional_protocols[provider_spec.api]
        additional_api, _, _ = additional_protocols[provider_spec.api]
        check_protocol_compliance(impl, additional_api)

    return impl

@@ -335,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
        raise ValueError(
            f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
        )


async def resolve_remote_stack_impls(
    config: RemoteProviderConfig,
    apis: List[str],
) -> Dict[Api, Any]:
    protocols = api_protocol_map()
    additional_protocols = additional_protocols_map()

    impls = {}
    for api_str in apis:
        api = Api(api_str)
        impls[api] = await get_client_impl(
            protocols[api],
            config,
            {},
        )
        if api in additional_protocols:
            _, additional_protocol, additional_api = additional_protocols[api]
            impls[additional_api] = await get_client_impl(
                additional_protocol,
                config,
                {},
            )

    return impls
@@ -7,8 +7,12 @@
from typing import Any

from llama_stack.distribution.datatypes import *  # noqa: F403

from llama_stack.distribution.store import DistributionRegistry

from .routing_tables import (
    DatasetsRoutingTable,
    EvalTasksRoutingTable,
    MemoryBanksRoutingTable,
    ModelsRoutingTable,
    ScoringFunctionsRoutingTable,

@@ -20,6 +24,7 @@ async def get_routing_table_impl(
    api: Api,
    impls_by_provider_id: Dict[str, RoutedProtocol],
    _deps,
    dist_registry: DistributionRegistry,
) -> Any:
    api_to_tables = {
        "memory_banks": MemoryBanksRoutingTable,

@@ -27,12 +32,13 @@ async def get_routing_table_impl(
        "shields": ShieldsRoutingTable,
        "datasets": DatasetsRoutingTable,
        "scoring_functions": ScoringFunctionsRoutingTable,
        "eval_tasks": EvalTasksRoutingTable,
    }

    if api.value not in api_to_tables:
        raise ValueError(f"API {api.value} not found in router map")

    impl = api_to_tables[api.value](impls_by_provider_id)
    impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
    await impl.initialize()
    return impl


@@ -40,6 +46,7 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
    from .routers import (
        DatasetIORouter,
        EvalRouter,
        InferenceRouter,
        MemoryRouter,
        SafetyRouter,

@@ -52,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
        "safety": SafetyRouter,
        "datasetio": DatasetIORouter,
        "scoring": ScoringRouter,
        "eval": EvalRouter,
    }
    if api.value not in api_to_routers:
        raise ValueError(f"API {api.value} not found in router map")
@@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.eval import *  # noqa: F403


class MemoryRouter(Memory):

@@ -31,8 +32,19 @@ class MemoryRouter(Memory):
    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
        await self.routing_table.register_memory_bank(memory_bank)
    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memorybank_id: Optional[str] = None,
    ) -> None:
        await self.routing_table.register_memory_bank(
            memory_bank_id,
            params,
            provider_id,
            provider_memorybank_id,
        )

    async def insert_documents(
        self,

@@ -70,12 +82,20 @@ class InferenceRouter(Inference):
    async def shutdown(self) -> None:
        pass

    async def register_model(self, model: ModelDef) -> None:
        await self.routing_table.register_model(model)
    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        await self.routing_table.register_model(
            model_id, provider_model_id, provider_id, metadata
        )

    async def chat_completion(
        self,
        model: str,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,

@@ -86,7 +106,7 @@ class InferenceRouter(Inference):
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        params = dict(
            model=model,
            model_id=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],

@@ -96,7 +116,7 @@ class InferenceRouter(Inference):
            stream=stream,
            logprobs=logprobs,
        )
        provider = self.routing_table.get_provider_impl(model)
        provider = self.routing_table.get_provider_impl(model_id)
        if stream:
            return (chunk async for chunk in await provider.chat_completion(**params))
        else:

@@ -104,16 +124,16 @@ class InferenceRouter(Inference):

    async def completion(
        self,
        model: str,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        provider = self.routing_table.get_provider_impl(model)
        provider = self.routing_table.get_provider_impl(model_id)
        params = dict(
            model=model,
            model_id=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,

@@ -127,11 +147,11 @@ class InferenceRouter(Inference):

    async def embeddings(
        self,
        model: str,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        return await self.routing_table.get_provider_impl(model).embeddings(
            model=model,
        return await self.routing_table.get_provider_impl(model_id).embeddings(
            model_id=model_id,
            contents=contents,
        )


@@ -149,17 +169,25 @@ class SafetyRouter(Safety):
    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: ShieldDef) -> None:
        await self.routing_table.register_shield(shield)
    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        return await self.routing_table.register_shield(
            shield_id, provider_shield_id, provider_id, params
        )

    async def run_shield(
        self,
        shield_type: str,
        shield_id: str,
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
        return await self.routing_table.get_provider_impl(shield_type).run_shield(
            shield_type=shield_type,
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

@@ -211,16 +239,16 @@ class ScoringRouter(Scoring):
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: List[str],
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        res = {}
        for fn_identifier in scoring_functions:
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score_batch(
                dataset_id=dataset_id,
                scoring_functions=[fn_identifier],
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)


@@ -232,17 +260,87 @@ class ScoringRouter(Scoring):
        )

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
    ) -> ScoreResponse:
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions:
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score(
                input_rows=input_rows,
                scoring_functions=[fn_identifier],
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_eval(
        self,
        task_id: str,
        task_config: AppEvalTaskConfig,
    ) -> Job:
        return await self.routing_table.get_provider_impl(task_id).run_eval(
            task_id=task_id,
            task_config=task_config,
        )

    @webmethod(route="/eval/evaluate_rows", method="POST")
    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
            task_id=task_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            task_config=task_config,
        )

    async def job_status(
        self,
        task_id: str,
        job_id: str,
    ) -> Optional[JobStatus]:
        return await self.routing_table.get_provider_impl(task_id).job_status(
            task_id, job_id
        )

    async def job_cancel(
        self,
        task_id: str,
        job_id: str,
    ) -> None:
        await self.routing_table.get_provider_impl(task_id).job_cancel(
            task_id,
            job_id,
        )

    async def job_result(
        self,
        task_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        return await self.routing_table.get_provider_impl(task_id).job_result(
            task_id,
            job_id,
        )
@@ -6,13 +6,21 @@

from typing import Any, Dict, List, Optional

from pydantic import parse_obj_as

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.eval_tasks import *  # noqa: F403

from llama_models.llama3.api.datatypes import URL

from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import *  # noqa: F403
@@ -20,88 +28,83 @@ def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:

# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    if obj.provider_id == "remote":
        # if this is just a passthrough, we want to let the remote
        # end actually do the registration with the correct provider
        obj = obj.model_copy(deep=True)
        obj.provider_id = ""

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        await p.register_model(obj)
        return await p.register_model(obj)
    elif api == Api.safety:
        await p.register_shield(obj)
        return await p.register_shield(obj)
    elif api == Api.memory:
        await p.register_memory_bank(obj)
        return await p.register_memory_bank(obj)
    elif api == Api.datasetio:
        await p.register_dataset(obj)
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        await p.register_scoring_function(obj)
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_eval_task(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.memory:
        return await p.unregister_memory_bank(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = Dict[str, List[RoutableObjectWithProvider]]


# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: Dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        self.registry: Registry = {}

        def add_objects(
        async def add_objects(
            objs: List[RoutableObjectWithProvider], provider_id: str, cls
        ) -> None:
            for obj in objs:
                if obj.identifier not in self.registry:
                    self.registry[obj.identifier] = []

                if cls is None:
                    obj.provider_id = provider_id
                else:
                    if provider_id == "remote":
                        # if this is just a passthrough, we got the *WithProvider object
                        # so we should just override the provider in-place
                        obj.provider_id = provider_id
                    else:
                        obj = cls(**obj.model_dump(), provider_id=provider_id)
                self.registry[obj.identifier].append(obj)

                        # Create a copy of the model data and explicitly set provider_id
                        model_data = obj.model_dump()
                        model_data["provider_id"] = provider_id
                        obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
                models = await p.list_models()
                add_objects(models, pid, ModelDefWithProvider)

            elif api == Api.safety:
                p.shield_store = self
                shields = await p.list_shields()
                add_objects(shields, pid, ShieldDefWithProvider)

            elif api == Api.memory:
                p.memory_bank_store = self
                memory_banks = await p.list_memory_banks()
                add_objects(memory_banks, pid, None)

            elif api == Api.datasetio:
                p.dataset_store = self
                datasets = await p.list_datasets()
                add_objects(datasets, pid, DatasetDefWithProvider)

            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
                await add_objects(scoring_functions, pid, ScoringFn)

            elif api == Api.eval:
                p.eval_task_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
@@ -121,42 +124,60 @@ class CommonRoutingTableImpl(RoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, EvalTasksRoutingTable):
                return ("Eval", "eval_task")
            else:
                raise ValueError("Unknown routing table type")

        if routing_key not in self.registry:
            apiname, objname = apiname_object()

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        objs = self.registry[routing_key]
        for obj in objs:
            if not provider_id or provider_id == obj.provider_id:
                return self.impls_by_provider_id[obj.provider_id]

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    def get_object_by_identifier(
        self, identifier: str
    async def get_object_by_identifier(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        objs = self.registry.get(identifier, [])
        if not objs:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # kind of ill-defined behavior here, but we'll just return the first one
        return objs[0]
        return obj

    async def register_object(self, obj: RoutableObjectWithProvider):
        entries = self.registry.get(obj.identifier, [])
        for entry in entries:
            if entry.provider_id == obj.provider_id or not obj.provider_id:
                print(
                    f"`{obj.identifier}` already registered with `{entry.provider_id}`"
                )
                return

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(
            obj, self.impls_by_provider_id[obj.provider_id]
        )

    # if provider_id is not specified, we'll pick an arbitrary one from existing entries
    async def register_object(
        self, obj: RoutableObjectWithProvider
    ) -> RoutableObjectWithProvider:
        # Get existing objects from registry
        existing_obj = await self.dist_registry.get(obj.type, obj.identifier)

        # Check for existing registration
        if existing_obj and existing_obj.provider_id == obj.provider_id:
            print(
                f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
            )
            return existing_obj

        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
@ -165,90 +186,252 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
await register_object_with_provider(obj, p)
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
if obj.identifier not in self.registry:
|
||||
self.registry[obj.identifier] = []
|
||||
self.registry[obj.identifier].append(obj)
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
# TODO: persist this to a store
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> List[ModelDefWithProvider]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def list_models(self) -> List[Model]:
        return await self.get_all_with_type("model")

    async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
        return self.get_object_by_identifier(identifier)

    async def get_model(self, identifier: str) -> Optional[Model]:
        return await self.get_object_by_identifier("model", identifier)

    async def register_model(self, model: ModelDefWithProvider) -> None:
        await self.register_object(model)

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        model = Model(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)
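A hedged sketch of the new keyword-based model registration; `routing_table` stands for a constructed ModelsRoutingTable and the ids are illustrative:

    # With a single configured provider, provider_id can be omitted.
    model = await routing_table.register_model(
        model_id="Llama3.1-8B-Instruct",           # assumed stack-side id
        provider_model_id="my-provider/llama-8b",  # assumed provider-side id
    )
    assert model.provider_resource_id == "my-provider/llama-8b"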
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> List[ShieldDef]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def list_shields(self) -> List[Shield]:
        return await self.get_all_with_type(ResourceType.shield.value)

    async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
        return self.get_object_by_identifier(shield_type)

    async def get_shield(self, identifier: str) -> Optional[Shield]:
        return await self.get_object_by_identifier("shield", identifier)

    async def register_shield(self, shield: ShieldDefWithProvider) -> None:

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = Shield(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def list_memory_banks(self) -> List[MemoryBank]:
        return await self.get_all_with_type(ResourceType.memory_bank.value)

    async def get_memory_bank(
        self, identifier: str
    ) -> Optional[MemoryBankDefWithProvider]:
        return self.get_object_by_identifier(identifier)

    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
        return await self.get_object_by_identifier("memory_bank", memory_bank_id)

    async def register_memory_bank(
        self, memory_bank: MemoryBankDefWithProvider
    ) -> None:

        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank:
        if provider_memory_bank_id is None:
            provider_memory_bank_id = memory_bank_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this memory bank
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        memory_bank = parse_obj_as(
            MemoryBank,
            {
                "identifier": memory_bank_id,
                "type": ResourceType.memory_bank.value,
                "provider_id": provider_id,
                "provider_resource_id": provider_memory_bank_id,
                **params.model_dump(),
            },
        )
        await self.register_object(memory_bank)
        return memory_bank

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        existing_bank = await self.get_memory_bank(memory_bank_id)
        if existing_bank is None:
            raise ValueError(f"Memory bank {memory_bank_id} not found")
        await self.unregister_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def list_datasets(self) -> List[Dataset]:
        return await self.get_all_with_type(ResourceType.dataset.value)

    async def get_dataset(
        self, dataset_identifier: str
    ) -> Optional[DatasetDefWithProvider]:
        return self.get_object_by_identifier(dataset_identifier)

    async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
        return await self.get_object_by_identifier("dataset", dataset_id)

    async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
        await self.register_object(dataset_def)

    async def register_dataset(
        self,
        dataset_id: str,
        dataset_schema: Dict[str, ParamType],
        url: URL,
        provider_dataset_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        if provider_dataset_id is None:
            provider_dataset_id = dataset_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this dataset
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if metadata is None:
            metadata = {}
        dataset = Dataset(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            dataset_schema=dataset_schema,
            url=url,
            metadata=metadata,
        )
        await self.register_object(dataset)
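Dataset registration follows the same pattern; a hedged sketch with illustrative schema and URL values (StringType comes from llama_stack.apis.common.type_system):

    await routing_table.register_dataset(
        dataset_id="my-eval-set",  # assumed dataset id
        dataset_schema={
            "input": StringType(),
            "expected_answer": StringType(),
        },
        url=URL(uri="https://example.com/eval.jsonl"),  # illustrative URL
    )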
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> List[ScoringFn]:
        return await self.get_all_with_type(ResourceType.scoring_function.value)

    async def get_scoring_function(
        self, name: str
    ) -> Optional[ScoringFnDefWithProvider]:
        return self.get_object_by_identifier(name)

    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
        return await self.get_object_by_identifier("scoring_function", scoring_fn_id)

    async def register_scoring_function(
        self, function_def: ScoringFnDefWithProvider
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[ScoringFnParams] = None,
    ) -> None:
        await self.register_object(function_def)

        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFn(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
    async def list_eval_tasks(self) -> List[EvalTask]:
        return await self.get_all_with_type(ResourceType.eval_task.value)

    async def get_eval_task(self, name: str) -> Optional[EvalTask]:
        return await self.get_object_by_identifier("eval_task", name)

    async def register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        metadata: Optional[Dict[str, Any]] = None,
        provider_eval_task_id: Optional[str] = None,
        provider_id: Optional[str] = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_eval_task_id is None:
            provider_eval_task_id = eval_task_id
        eval_task = EvalTask(
            identifier=eval_task_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_eval_task_id,
        )
        await self.register_object(eval_task)
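Eval task registration ties a dataset to its scoring functions; a hedged sketch with made-up identifiers:

    await routing_table.register_eval_task(
        eval_task_id="mmlu-basic",              # assumed task id
        dataset_id="my-eval-set",               # assumed dataset id
        scoring_functions=["basic::equality"],  # assumed scoring fn id
    )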
@@ -8,8 +8,12 @@ import asyncio
import functools
import inspect
import json
import os
import re
import signal
import sys
import traceback
import warnings

from contextlib import asynccontextmanager
from ssl import SSLError

@@ -26,10 +30,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated

from llama_stack.distribution.distribution import (
    builtin_automatically_routed_apis,
    get_provider_registry,
)

from llama_stack.distribution.distribution import builtin_automatically_routed_apis

from llama_stack.providers.utils.telemetry.tracing import (
    end_trace,
@@ -38,16 +39,26 @@ from llama_stack.providers.utils.telemetry.tracing import (
    start_trace,
)
from llama_stack.distribution.datatypes import *  # noqa: F403

from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import construct_stack

from .endpoints import get_all_api_endpoints


def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, "write") else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))


if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.json()
        data = data.model_dump_json()
    else:
        data = json.dumps(data)
@@ -184,15 +195,6 @@ async def lifespan(app: FastAPI):
        await impl.shutdown()


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def is_streaming_request(func_name: str, request: Request, **kwargs):
    # TODO: pass the api method and punt it to the Protocol definition directly
    return kwargs.get("stream", False)
@@ -206,7 +208,8 @@ async def maybe_await(value):

async def sse_generator(event_gen):
    try:
        async for item in await event_gen:
        event_gen = await event_gen
        async for item in event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
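The fix above matters because the endpoint handler returns a coroutine that itself resolves to an async generator; it must be awaited once before iteration. A minimal self-contained illustration of the pattern (names are made up):

    # Resolve the coroutine first, then iterate the async generator it returns.
    async def make_stream():
        async def gen():
            yield "chunk-1"
            yield "chunk-2"
        return gen()

    async def consume():
        event_gen = await make_stream()
        async for item in event_gen:
            print(item)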
@@ -226,7 +229,6 @@ async def sse_generator(event_gen):


def create_dynamic_typed_route(func: Any, method: str):

    async def endpoint(request: Request, **kwargs):
        await start_trace(func.__name__)
@@ -269,17 +271,74 @@ def create_dynamic_typed_route(func: Any, method: str):
    return endpoint


class EnvVarError(Exception):
    def __init__(self, var_name: str, path: str = ""):
        self.var_name = var_name
        self.path = path
        super().__init__(
            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
        )


def replace_env_vars(config: Any, path: str = "") -> Any:
    if isinstance(config, dict):
        result = {}
        for k, v in config.items():
            try:
                result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
            except EnvVarError as e:
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, list):
        result = []
        for i, v in enumerate(config):
            try:
                result.append(replace_env_vars(v, f"{path}[{i}]"))
            except EnvVarError as e:
                raise EnvVarError(e.var_name, e.path) from None
        return result

    elif isinstance(config, str):
        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"

        def get_env_var(match):
            env_var = match.group(1)
            default_val = match.group(2)

            value = os.environ.get(env_var)
            if not value:
                if default_val is None:
                    raise EnvVarError(env_var, path)
                else:
                    value = default_val

            return value

        try:
            return re.sub(pattern, get_env_var, config)
        except EnvVarError as e:
            raise EnvVarError(e.var_name, e.path) from None

    return config
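The substitution syntax is ${env.VAR} with an optional :default. A quick sketch of what replace_env_vars does to a config mapping (variable names are illustrative):

    import os

    os.environ["LLAMA_PORT"] = "8080"
    cfg = {
        "port": "${env.LLAMA_PORT}",        # resolved from the environment
        "api_key": "${env.API_KEY:dummy}",  # falls back to the default "dummy"
    }
    print(replace_env_vars(cfg))
    # -> {'port': '8080', 'api_key': 'dummy'}   (assuming API_KEY is unset)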
def main(
    yaml_config: str = "llamastack-run.yaml",
    port: int = 5000,
    disable_ipv6: bool = False,
):
    with open(yaml_config, "r") as fp:
        config = StackRunConfig(**yaml.safe_load(fp))

        config = replace_env_vars(yaml.safe_load(fp))
        config = StackRunConfig(**config)

    app = FastAPI()

    impls = asyncio.run(resolve_impls(config, get_provider_registry()))

    try:
        impls = asyncio.run(construct_stack(config))
    except InvalidProviderError:
        sys.exit(1)

    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
@@ -303,28 +362,19 @@ def main(
        endpoints = all_endpoints[api]
        impl = impls[api]

        if is_passthrough(impl.__provider_spec__):
            for endpoint in endpoints:
                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

        for endpoint in endpoints:
            if not hasattr(impl, endpoint.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

                impl_method = getattr(impl, endpoint.name)
            impl_method = getattr(impl, endpoint.name)

                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        endpoint.method,
                    )
            getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                create_dynamic_typed_route(
                    impl_method,
                    endpoint.method,
                )
            )

    cprint(f"Serving API {api_str}", "white", attrs=["bold"])
    for endpoint in endpoints:
llama_stack/distribution/stack.py (new file, 107 lines)

@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from termcolor import colored

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.agents import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.eval import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.batch_inference import *  # noqa: F403
from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.telemetry import *  # noqa: F403
from llama_stack.apis.post_training import *  # noqa: F403
from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.inspect import *  # noqa: F403
from llama_stack.apis.eval_tasks import *  # noqa: F403

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
class LlamaStack(
    MemoryBanks,
    Inference,
    BatchInference,
    Agents,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    Memory,
    Eval,
    EvalTasks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
    Models,
    Shields,
    Inspect,
):
    pass
RESOURCES = [
    ("models", Api.models, "register_model", "list_models"),
    ("shields", Api.shields, "register_shield", "list_shields"),
    ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
    (
        "scoring_fns",
        Api.scoring_functions,
        "register_scoring_function",
        "list_scoring_functions",
    ),
    ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
    for rsrc, api, register_method, list_method in RESOURCES:
        objects = getattr(run_config, rsrc)
        if api not in impls:
            continue

        method = getattr(impls[api], register_method)
        for obj in objects:
            await method(**obj.model_dump())

        method = getattr(impls[api], list_method)
        for obj in await method():
            print(
                f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
            )

    print("")


# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
    dist_registry, _ = await create_dist_registry(
        run_config.metadata_store, run_config.image_name
    )
    impls = await resolve_impls(
        run_config, provider_registry or get_provider_registry(), dist_registry
    )
    await register_resources(run_config, impls)
    return impls
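A hedged sketch of how a caller embeds the new entry point; the config path is illustrative and replace_env_vars lives in the server module:

    import asyncio
    import yaml

    with open("llamastack-run.yaml", "r") as fp:
        run_config = StackRunConfig(**replace_env_vars(yaml.safe_load(fp)))

    impls = asyncio.run(construct_stack(run_config))
    inference = impls[Api.inference]  # routed Inference implementation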
@@ -10,6 +10,8 @@ DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}

set -euo pipefail

@@ -54,11 +56,18 @@ if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
  DOCKER_OPTS="$DOCKER_OPTS --gpus=all"
fi

version_tag="latest"
if [ -n "$PYPI_VERSION" ]; then
  version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then
  version_tag="test-$TEST_PYPI_VERSION"
fi

$DOCKER_BINARY run $DOCKER_OPTS -it \
  -p $port:$port \
  -v "$yaml_config:/app/config.yaml" \
  $mounts \
  $docker_image \
  $docker_image:$version_tag \
  python -m llama_stack.distribution.server.server \
  --yaml_config /app/config.yaml \
  --port $port "$@"
llama_stack/distribution/store/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .registry import *  # noqa: F401 F403

llama_stack/distribution/store/registry.py (new file, 221 lines)
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple

import pydantic

from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

from llama_stack.providers.utils.kvstore import (
    KVStore,
    kvstore_impl,
    SqliteKVStoreConfig,
)


class DistributionRegistry(Protocol):
    async def get_all(self) -> List[RoutableObjectWithProvider]: ...

    async def initialize(self) -> None: ...

    async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...

    def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...

    async def update(
        self, obj: RoutableObjectWithProvider
    ) -> RoutableObjectWithProvider: ...

    async def register(self, obj: RoutableObjectWithProvider) -> bool: ...

    async def delete(self, type: str, identifier: str) -> None: ...
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v2"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
def _get_registry_key_range() -> Tuple[str, str]:
|
||||
"""Returns the start and end keys for the registry range query."""
|
||||
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
|
||||
return start_key, f"{start_key}\xff"
|
||||
|
||||
|
||||
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
|
||||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||
all_objects = []
|
||||
for value in values:
|
||||
obj = pydantic.parse_obj_as(
|
||||
RoutableObjectWithProvider,
|
||||
json.loads(value),
|
||||
)
|
||||
all_objects.append(obj)
|
||||
return all_objects
|
||||
|
||||
|
||||
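For illustration, the key layout this format produces (identifier made up):

    # Expands to: distributions:registry:v2::model:Llama3.1-8B-Instruct
    key = KEY_FORMAT.format(type="model", identifier="Llama3.1-8B-Instruct")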
class DiskDistributionRegistry(DistributionRegistry):
    def __init__(self, kvstore: KVStore):
        self.kvstore = kvstore

    async def initialize(self) -> None:
        pass

    def get_cached(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        # Disk registry does not have a cache
        raise NotImplementedError("Disk registry does not have a cache")

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        start_key, end_key = _get_registry_key_range()
        values = await self.kvstore.range(start_key, end_key)
        return _parse_registry_values(values)

    async def get(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        json_str = await self.kvstore.get(
            KEY_FORMAT.format(type=type, identifier=identifier)
        )
        if not json_str:
            return None

        objects_data = json.loads(json_str)
        # Return only the first object if any exist
        if objects_data:
            return pydantic.parse_obj_as(
                RoutableObjectWithProvider,
                json.loads(objects_data),
            )
        return None

    async def update(self, obj: RoutableObjectWithProvider) -> None:
        await self.kvstore.set(
            KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
            obj.model_dump_json(),
        )
        return obj

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        existing_obj = await self.get(obj.type, obj.identifier)
        # don't register if the object's provider_id already exists
        if existing_obj and existing_obj.provider_id == obj.provider_id:
            return False

        await self.kvstore.set(
            KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
            obj.model_dump_json(),
        )
        return True

    async def delete(self, type: str, identifier: str) -> None:
        await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
    def __init__(self, kvstore: KVStore):
        super().__init__(kvstore)
        self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
        self._initialized = False
        self._initialize_lock = asyncio.Lock()
        self._cache_lock = asyncio.Lock()

    @asynccontextmanager
    async def _locked_cache(self):
        """Context manager for safely accessing the cache with a lock."""
        async with self._cache_lock:
            yield self.cache

    async def _ensure_initialized(self):
        """Ensures the registry is initialized before operations."""
        if self._initialized:
            return

        async with self._initialize_lock:
            if self._initialized:
                return

            start_key, end_key = _get_registry_key_range()
            values = await self.kvstore.range(start_key, end_key)
            objects = _parse_registry_values(values)

            async with self._locked_cache() as cache:
                for obj in objects:
                    cache_key = (obj.type, obj.identifier)
                    cache[cache_key] = obj

            self._initialized = True

    async def initialize(self) -> None:
        await self._ensure_initialized()

    def get_cached(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        return self.cache.get((type, identifier), None)

    async def get_all(self) -> List[RoutableObjectWithProvider]:
        await self._ensure_initialized()
        async with self._locked_cache() as cache:
            return list(cache.values())

    async def get(
        self, type: str, identifier: str
    ) -> Optional[RoutableObjectWithProvider]:
        await self._ensure_initialized()
        cache_key = (type, identifier)

        async with self._locked_cache() as cache:
            return cache.get(cache_key, None)

    async def register(self, obj: RoutableObjectWithProvider) -> bool:
        await self._ensure_initialized()
        success = await super().register(obj)

        if success:
            cache_key = (obj.type, obj.identifier)
            async with self._locked_cache() as cache:
                cache[cache_key] = obj

        return success

    async def update(self, obj: RoutableObjectWithProvider) -> None:
        await super().update(obj)
        cache_key = (obj.type, obj.identifier)
        async with self._locked_cache() as cache:
            cache[cache_key] = obj
        return obj

    async def delete(self, type: str, identifier: str) -> None:
        await super().delete(type, identifier)
        cache_key = (type, identifier)
        async with self._locked_cache() as cache:
            if cache_key in cache:
                del cache[cache_key]


async def create_dist_registry(
    metadata_store: Optional[KVStoreConfig],
    image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
    # instantiate kvstore for storing and retrieving distribution metadata
    if metadata_store:
        dist_kvstore = await kvstore_impl(metadata_store)
    else:
        dist_kvstore = await kvstore_impl(
            SqliteKVStoreConfig(
                db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
            )
        )
    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
    await dist_registry.initialize()
    return dist_registry, dist_kvstore
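A hedged sketch of exercising the cached registry directly; the db path, ids, and provider name are illustrative (Model is imported as in the test file below):

    import asyncio

    async def demo():
        kvstore = await kvstore_impl(
            SqliteKVStoreConfig(db_path="/tmp/demo_registry.db")
        )
        registry = CachedDiskDistributionRegistry(kvstore)
        await registry.initialize()

        model = Model(
            identifier="demo-model",
            provider_resource_id="demo-model",
            provider_id="demo-provider",
        )
        await registry.register(model)
        cached = registry.get_cached("model", "demo-model")  # served from memory
        assert cached is not None and cached.provider_id == "demo-provider"

    asyncio.run(demo())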
llama_stack/distribution/store/tests/test_registry.py (new file, 215 lines)

@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest
import pytest_asyncio
from llama_stack.distribution.store import *  # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import *  # noqa F403


@pytest.fixture
def config():
    config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
    if os.path.exists(config.db_path):
        os.remove(config.db_path)
    return config


@pytest_asyncio.fixture
async def registry(config):
    registry = DiskDistributionRegistry(await kvstore_impl(config))
    await registry.initialize()
    return registry


@pytest_asyncio.fixture
async def cached_registry(config):
    registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await registry.initialize()
    return registry


@pytest.fixture
def sample_bank():
    return VectorMemoryBank(
        identifier="test_bank",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
        provider_resource_id="test_bank",
        provider_id="test-provider",
    )


@pytest.fixture
def sample_model():
    return Model(
        identifier="test_model",
        provider_resource_id="test_model",
        provider_id="test-provider",
    )
@pytest.mark.asyncio
async def test_registry_initialization(registry):
    # Test empty registry
    results = await registry.get("nonexistent", "nonexistent")
    assert len(results) == 0


@pytest.mark.asyncio
async def test_basic_registration(registry, sample_bank, sample_model):
    print(f"Registering {sample_bank}")
    await registry.register(sample_bank)
    print(f"Registering {sample_model}")
    await registry.register(sample_model)
    print("Getting bank")
    results = await registry.get("memory_bank", "test_bank")
    assert len(results) == 1
    result_bank = results[0]
    assert result_bank.identifier == sample_bank.identifier
    assert result_bank.embedding_model == sample_bank.embedding_model
    assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
    assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
    assert result_bank.provider_id == sample_bank.provider_id

    results = await registry.get("model", "test_model")
    assert len(results) == 1
    result_model = results[0]
    assert result_model.identifier == sample_model.identifier
    assert result_model.provider_id == sample_model.provider_id


@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_bank, sample_model):
    # First populate the disk registry
    disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
    await disk_registry.initialize()
    await disk_registry.register(sample_bank)
    await disk_registry.register(sample_model)

    # Test cached version loads from disk
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    results = await cached_registry.get("memory_bank", "test_bank")
    assert len(results) == 1
    result_bank = results[0]
    assert result_bank.identifier == sample_bank.identifier
    assert result_bank.embedding_model == sample_bank.embedding_model
    assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
    assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
    assert result_bank.provider_id == sample_bank.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    new_bank = VectorMemoryBank(
        identifier="test_bank_2",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=256,
        overlap_size_in_tokens=32,
        provider_resource_id="test_bank_2",
        provider_id="baz",
    )
    await cached_registry.register(new_bank)

    # Verify in cache
    results = await cached_registry.get("memory_bank", "test_bank_2")
    assert len(results) == 1
    result_bank = results[0]
    assert result_bank.identifier == new_bank.identifier
    assert result_bank.provider_id == new_bank.provider_id

    # Verify persisted to disk
    new_registry = DiskDistributionRegistry(await kvstore_impl(config))
    await new_registry.initialize()
    results = await new_registry.get("memory_bank", "test_bank_2")
    assert len(results) == 1
    result_bank = results[0]
    assert result_bank.identifier == new_bank.identifier
    assert result_bank.provider_id == new_bank.provider_id


@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    original_bank = VectorMemoryBank(
        identifier="test_bank_2",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=256,
        overlap_size_in_tokens=32,
        provider_resource_id="test_bank_2",
        provider_id="baz",
    )
    await cached_registry.register(original_bank)

    duplicate_bank = VectorMemoryBank(
        identifier="test_bank_2",
        embedding_model="different-model",
        chunk_size_in_tokens=128,
        overlap_size_in_tokens=16,
        provider_resource_id="test_bank_2",
        provider_id="baz",  # Same provider_id
    )
    await cached_registry.register(duplicate_bank)

    results = await cached_registry.get("memory_bank", "test_bank_2")
    assert len(results) == 1  # Still only one result
    assert (
        results[0].embedding_model == original_bank.embedding_model
    )  # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    # Create multiple test banks
    test_banks = [
        VectorMemoryBank(
            identifier=f"test_bank_{i}",
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=256,
            overlap_size_in_tokens=32,
            provider_resource_id=f"test_bank_{i}",
            provider_id=f"provider_{i}",
        )
        for i in range(3)
    ]

    # Register all banks
    for bank in test_banks:
        await cached_registry.register(bank)

    # Test get_all retrieval
    all_results = await cached_registry.get_all()
    assert len(all_results) == 3

    # Verify each bank was stored correctly
    for original_bank in test_banks:
        matching_banks = [
            b for b in all_results if b.identifier == original_bank.identifier
        ]
        assert len(matching_banks) == 1
        stored_bank = matching_banks[0]
        assert stored_bank.embedding_model == original_bank.embedding_model
        assert stored_bank.provider_id == original_bank.provider_id
        assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
        assert (
            stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
        )
@@ -10,4 +10,5 @@ from .config_dirs import DEFAULT_CHECKPOINT_DIR


def model_local_dir(descriptor: str) -> str:
    return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)

    path = os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
    return path.replace(":", "-")
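Model descriptors can carry a colon-separated variant suffix, which is awkward in filesystem paths; the new version flattens it. A quick illustration (descriptor made up):

    # Before: <checkpoints>/Llama-Guard-3-1B:int4-mp1   (":" kept in the path)
    # After:  <checkpoints>/Llama-Guard-3-1B-int4-mp1   (":" replaced with "-")
    print(model_local_dir("Llama-Guard-3-1B:int4-mp1"))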
@@ -1,187 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from fireworks.client import Fireworks

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)

from .config import FireworksImplConfig


FIREWORKS_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
    "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
    "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
    "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
    "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
    "Llama3.2-11B-Vision-Instruct": "llama-v3p2-11b-vision-instruct",
    "Llama3.2-90B-Vision-Instruct": "llama-v3p2-90b-vision-instruct",
}


class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: FireworksImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = CompletionRequest(
            model=model,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        client = Fireworks(api_key=self.config.api_key)
        if stream:
            return self._stream_completion(request, client)
        else:
            return await self._nonstream_completion(request, client)

    async def _nonstream_completion(
        self, request: CompletionRequest, client: Fireworks
    ) -> CompletionResponse:
        params = self._get_params(request)
        r = await client.completion.acreate(**params)
        return process_completion_response(r, self.formatter)

    async def _stream_completion(
        self, request: CompletionRequest, client: Fireworks
    ) -> AsyncGenerator:
        params = self._get_params(request)

        stream = client.completion.acreate(**params)
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk
    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        client = Fireworks(api_key=self.config.api_key)
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = await client.completion.acreate(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
    ) -> AsyncGenerator:
        params = self._get_params(request)

        stream = client.completion.acreate(**params)
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request) -> dict:
        prompt = ""
        if type(request) == ChatCompletionRequest:
            prompt = chat_completion_request_to_prompt(request, self.formatter)
        elif type(request) == CompletionRequest:
            prompt = completion_request_to_prompt(request, self.formatter)
        else:
            raise ValueError(f"Unknown request type {type(request)}")

        # Fireworks always prepends with BOS
        if prompt.startswith("<|begin_of_text|>"):
            prompt = prompt[len("<|begin_of_text|>") :]

        options = get_sampling_options(request.sampling_params)
        options.setdefault("max_tokens", 512)

        if fmt := request.response_format:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["response_format"] = {
                    "type": "json_object",
                    "schema": fmt.json_schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                options["response_format"] = {
                    "type": "grammar",
                    "grammar": fmt.bnf,
                }
            else:
                raise ValueError(f"Unknown response format {fmt.type}")
        return {
            "model": self.map_to_provider_model(request.model),
            "prompt": prompt,
            "stream": request.stream,
            **options,
        }

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
|
@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field


class BedrockSafetyConfig(BaseModel):
    """Configuration information for a guardrail that you want to use in the request."""

    aws_profile: str = Field(
        default="default",
        description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
    )

@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


class TogetherProviderDataValidator(BaseModel):
    together_api_key: str


@json_schema_type
class TogetherSafetyConfig(BaseModel):
    url: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together AI server",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The Together AI API Key (default for the distribution, if any)",
    )

@@ -1,101 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from together import Together

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate

from .config import TogetherSafetyConfig


TOGETHER_SHIELD_MODEL_MAP = {
    "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}


class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate):
    def __init__(self, config: TogetherSafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: ShieldDef) -> None:
        raise ValueError("Registering dynamic shields is not supported")

    async def list_shields(self) -> List[ShieldDef]:
        return [
            ShieldDef(
                identifier=ShieldType.llama_guard.value,
                type=ShieldType.llama_guard.value,
                params={},
            )
        ]

    async def run_shield(
        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse:
        shield_def = await self.shield_store.get_shield(shield_type)
        if not shield_def:
            raise ValueError(f"Unknown shield {shield_type}")

        model = shield_def.params.get("model", "llama_guard")
        if model not in TOGETHER_SHIELD_MODEL_MAP:
            raise ValueError(f"Unsupported safety model: {model}")

        together_api_key = None
        if self.config.api_key is not None:
            together_api_key = self.config.api_key
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.together_api_key:
                raise ValueError(
                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
                )
            together_api_key = provider_data.together_api_key

        # messages can have role assistant or user
        api_messages = []
        for message in messages:
            if message.role in (Role.user.value, Role.assistant.value):
                api_messages.append({"role": message.role, "content": message.content})

        violation = await get_safety_response(
            together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
        )
        return RunShieldResponse(violation=violation)


async def get_safety_response(
    api_key: str, model_name: str, messages: List[Dict[str, str]]
) -> Optional[SafetyViolation]:
    client = Together(api_key=api_key)
    response = client.chat.completions.create(messages=messages, model=model_name)
    if len(response.choices) == 0:
        return None

    response_text = response.choices[0].message.content
    if response_text == "safe":
        return None

    parts = response_text.split("\n")
    if len(parts) != 2:
        return None

    if parts[0] == "unsafe":
        return SafetyViolation(
            violation_level=ViolationLevel.ERROR,
            metadata={"violation_type": parts[1]},
        )

    return None

@@ -6,15 +6,17 @@

from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.shields import ShieldDef
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield


@json_schema_type

@@ -34,39 +36,42 @@ class Api(Enum):
    memory_banks = "memory_banks"
    datasets = "datasets"
    scoring_functions = "scoring_functions"
    eval_tasks = "eval_tasks"

    # built-in API
    inspect = "inspect"


class ModelsProtocolPrivate(Protocol):
    async def list_models(self) -> List[ModelDef]: ...
    async def register_model(self, model: Model) -> None: ...

    async def register_model(self, model: ModelDef) -> None: ...
    async def unregister_model(self, model_id: str) -> None: ...


class ShieldsProtocolPrivate(Protocol):
    async def list_shields(self) -> List[ShieldDef]: ...

    async def register_shield(self, shield: ShieldDef) -> None: ...
    async def register_shield(self, shield: Shield) -> None: ...


class MemoryBanksProtocolPrivate(Protocol):
    async def list_memory_banks(self) -> List[MemoryBankDef]: ...
    async def list_memory_banks(self) -> List[MemoryBank]: ...

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
    async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...

    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...


class DatasetsProtocolPrivate(Protocol):
    async def list_datasets(self) -> List[DatasetDef]: ...

    async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
    async def register_dataset(self, dataset: Dataset) -> None: ...


class ScoringFunctionsProtocolPrivate(Protocol):
    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
    async def list_scoring_functions(self) -> List[ScoringFn]: ...

    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...


class EvalTasksProtocolPrivate(Protocol):
    async def register_eval_task(self, eval_task: EvalTask) -> None: ...


@json_schema_type

@@ -81,6 +86,14 @@ class ProviderSpec(BaseModel):
        default_factory=list,
        description="Higher-level API surfaces may depend on other providers to provide their functionality",
    )
    deprecation_warning: Optional[str] = Field(
        default=None,
        description="If this provider is deprecated, specify the warning message here",
    )
    deprecation_error: Optional[str] = Field(
        default=None,
        description="If this provider is deprecated and does NOT work, specify the error message here",
    )

    # used internally by the resolver; this is a hack for now
    deps__: List[str] = Field(default_factory=list)

@@ -90,6 +103,7 @@ class RoutingTable(Protocol):
    def get_provider_impl(self, routing_key: str) -> Any: ...


# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
    adapter_type: str = Field(

@@ -145,21 +159,27 @@ Fully-qualified name of the module to import. The module is expected to have:

class RemoteProviderConfig(BaseModel):
    host: str = "localhost"
    port: int
    port: Optional[int] = None
    protocol: str = "http"

    @property
    def url(self) -> str:
        return f"http://{self.host}:{self.port}"
        if self.port is None:
            return f"{self.protocol}://{self.host}"
        return f"{self.protocol}://{self.host}:{self.port}"

    @classmethod
    def from_url(cls, url: str) -> "RemoteProviderConfig":
        parsed = urlparse(url)
        return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)


@json_schema_type
class RemoteProviderSpec(ProviderSpec):
    adapter: Optional[AdapterSpec] = Field(
        default=None,
    adapter: AdapterSpec = Field(
        description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
API responses, specify the adapter here.
""",
    )

@@ -169,38 +189,21 @@ as being "Llama Stack compatible"

    @property
    def module(self) -> str:
        if self.adapter:
            return self.adapter.module
        return "llama_stack.distribution.client"
        return self.adapter.module

    @property
    def pip_packages(self) -> List[str]:
        if self.adapter:
            return self.adapter.pip_packages
        return []
        return self.adapter.pip_packages

    @property
    def provider_data_validator(self) -> Optional[str]:
        if self.adapter:
            return self.adapter.provider_data_validator
        return None
        return self.adapter.provider_data_validator


def is_passthrough(spec: ProviderSpec) -> bool:
    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None


# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
    api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
    config_class = (
        adapter.config_class
        if adapter and adapter.config_class
        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
    )
    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"

def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
    return RemoteProviderSpec(
        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
        api=api,
        provider_type=f"remote::{adapter.adapter_type}",
        config_class=adapter.config_class,
        adapter=adapter,
    )

@@ -1,120 +0,0 @@
# LocalInference

LocalInference provides a local inference implementation powered by [executorch](https://github.com/pytorch/executorch/).

Llama Stack currently supports on-device inference for iOS, with Android coming soon. You can run on-device inference on Android today using [executorch](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo), PyTorch’s on-device inference library.

## Installation

We're working on making LocalInference easier to set up. For now, you'll need to import it via `.xcframework`:

1. Clone the executorch submodule in this repo and its dependencies: `git submodule update --init --recursive`
1. Install [CMake](https://cmake.org/) for the executorch build
1. Drag `LocalInference.xcodeproj` into your project
1. Add `LocalInference` as a framework in your app target
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
1. Add all the kernels / backends from executorch (but not executorch itself!) as frameworks in your app target:
    - backend_coreml
    - backend_mps
    - backend_xnnpack
    - kernels_custom
    - kernels_optimized
    - kernels_portable
    - kernels_quantized
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:

```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```

1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:

```
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
-force_load
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
```

## Preparing a model

1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app

We now support models quantized using SpinQuant and QAT+LoRA, which offer a significant performance boost (demo app on iPhone 13 Pro):

| Llama 3.2 1B | Tokens / Second (total) | | Time-to-First-Token (sec) | |
| :---- | :---- | :---- | :---- | :---- |
| | Haiku | Paragraph | Haiku | Paragraph |
| BF16 | 2.2 | 2.5 | 2.3 | 1.9 |
| QAT+LoRA | 7.1 | 3.3 | 0.37 | 0.24 |
| SpinQuant | 10.1 | 5.2 | 0.2 | 0.2 |

## Using LocalInference

1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:

```swift
init () {
    runnerQueue = DispatchQueue(label: "org.meta.llamastack")
    inferenceService = LocalInferenceService(queue: runnerQueue)
    agentsService = LocalAgentsService(inference: inferenceService)
}
```

2. Before making any inference calls, load your model from your bundle:

```swift
let mainBundle = Bundle.main
inferenceService.loadModel(
    modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
    tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
    completion: {_ in } // use to handle load failures
)
```
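
The empty completion handler above silently swallows load failures. A minimal sketch of surfacing them instead, assuming the closure's single argument carries the load result (hypothetical; check the `LocalInferenceService` API for the real signature):

```swift
inferenceService.loadModel(
    modelPath: mainBundle.url(forResource: "llama32_1b_spinquant", withExtension: "pte"),
    tokenizerPath: mainBundle.url(forResource: "tokenizer", withExtension: "model"),
    completion: { result in
        // Hypothetical handler: log whatever the service reports so a
        // failed load is visible during development.
        print("loadModel finished: \(result)")
    }
)
```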

3. Make inference calls (or agents calls) as you normally would with LlamaStack:

```swift
for await chunk in try await agentsService.initAndCreateTurn(
    messages: [
        .UserMessage(Components.Schemas.UserMessage(
            content: .case1("Call functions as needed to handle any actions in the following text:\n\n" + text),
            role: .user))
    ]
) {
```
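
The snippet above elides the loop body. As a rough sketch of what goes inside it, assuming the streamed chunk type comes from the generated `Components.Schemas` OpenAPI types (illustrative only):

```swift
for await chunk in try await agentsService.initAndCreateTurn(
    messages: [ /* as above */ ]
) {
    // Hypothetical handling: print each streamed event to verify that
    // turns are streaming; a real app would switch over the payload
    // types defined by the generated Components.Schemas.
    print(chunk)
}
```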

## Troubleshooting

If you receive errors like "missing package product" or "invalid checksum", try cleaning the build folder and resetting the Swift package cache:

(Opt+Click) Product > Clean Build Folder Immediately

```
rm -rf \
  ~/Library/org.swift.swiftpm \
  ~/Library/Caches/org.swift.swiftpm \
  ~/Library/Caches/com.apple.dt.Xcode \
  ~/Library/Developer/Xcode/DerivedData
```

@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import SafetyConfig


async def get_provider_impl(config: SafetyConfig, deps):
    from .safety import MetaReferenceSafetyImpl

    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

    impl = MetaReferenceSafetyImpl(config, deps)
    await impl.initialize()
    return impl

@@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List

from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import *  # noqa: F403

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
    IGNORE = 0
    WARN = 1
    RAISE = 2


class ShieldBase(ABC):
    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        self.on_violation_action = on_violation_action

    @abstractmethod
    async def run(self, messages: List[Message]) -> ShieldResponse:
        raise NotImplementedError()


def message_content_as_str(message: Message) -> str:
    return interleaved_text_media_as_str(message.content)


class TextShield(ShieldBase):
    def convert_messages_to_text(self, messages: List[Message]) -> str:
        return "\n".join([message_content_as_str(m) for m in messages])

    async def run(self, messages: List[Message]) -> ShieldResponse:
        text = self.convert_messages_to_text(messages)
        return await self.run_impl(text)

    @abstractmethod
    async def run_impl(self, text: str) -> ShieldResponse:
        raise NotImplementedError()

@@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import List, Optional

from llama_models.sku_list import CoreModelId, safety_models

from pydantic import BaseModel, field_validator


class PromptGuardType(Enum):
    injection = "injection"
    jailbreak = "jailbreak"


class LlamaGuardShieldConfig(BaseModel):
    model: str = "Llama-Guard-3-1B"
    excluded_categories: List[str] = []

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = [
            m.descriptor()
            for m in safety_models()
            if (
                m.core_model_id
                in {
                    CoreModelId.llama_guard_3_8b,
                    CoreModelId.llama_guard_3_1b,
                    CoreModelId.llama_guard_3_11b_vision,
                }
            )
        ]
        if model not in permitted_models:
            raise ValueError(
                f"Invalid model: {model}. Must be one of {permitted_models}"
            )
        return model


class SafetyConfig(BaseModel):
    llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
    enable_prompt_guard: Optional[bool] = False

@@ -1,145 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import auto, Enum
from typing import List

import torch

from llama_models.llama3.api.datatypes import Message
from termcolor import cprint

from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield


class PromptGuardShield(TextShield):
    class Mode(Enum):
        INJECTION = auto()
        JAILBREAK = auto()

    _instances = {}
    _model_cache = None

    @staticmethod
    def instance(
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action=OnViolationAction.RAISE,
    ) -> "PromptGuardShield":
        action_value = on_violation_action.value
        key = (model_dir, threshold, temperature, mode, action_value)
        if key not in PromptGuardShield._instances:
            PromptGuardShield._instances[key] = PromptGuardShield(
                model_dir=model_dir,
                threshold=threshold,
                temperature=temperature,
                mode=mode,
                on_violation_action=on_violation_action,
            )
        return PromptGuardShield._instances[key]

    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(on_violation_action)
        assert (
            model_dir is not None
        ), "Must provide a model directory for prompt injection shield"
        if temperature <= 0:
            raise ValueError("Temperature must be greater than 0")
        self.device = "cuda"
        if PromptGuardShield._model_cache is None:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            # load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, device_map=self.device
            )
            PromptGuardShield._model_cache = (tokenizer, model)

        self.tokenizer, self.model = PromptGuardShield._model_cache
        self.temperature = temperature
        self.threshold = threshold
        self.mode = mode

    def convert_messages_to_text(self, messages: List[Message]) -> str:
        return message_content_as_str(messages[-1])

    async def run_impl(self, text: str) -> ShieldResponse:
        # run model on messages and return response
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs[0]
        probabilities = torch.softmax(logits / self.temperature, dim=-1)
        score_embedded = probabilities[0, 1].item()
        score_malicious = probabilities[0, 2].item()
        cprint(
            f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
            color="magenta",
        )

        if self.mode == self.Mode.INJECTION and (
            score_embedded + score_malicious > self.threshold
        ):
            return ShieldResponse(
                is_violation=True,
                violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )
        elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
            return ShieldResponse(
                is_violation=True,
                violation_type=f"prompt_injection:malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )

        return ShieldResponse(
            is_violation=False,
        )


class JailbreakShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.JAILBREAK,
            on_violation_action=on_violation_action,
        )


class InjectionShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.INJECTION,
            on_violation_action=on_violation_action,
        )

@@ -1,112 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.distribution.datatypes import Api

from llama_stack.providers.datatypes import ShieldsProtocolPrivate

from .base import OnViolationAction, ShieldBase
from .config import SafetyConfig
from .llama_guard import LlamaGuardShield
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield


PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
    def __init__(self, config: SafetyConfig, deps) -> None:
        self.config = config
        self.inference_api = deps[Api.inference]

        self.available_shields = []
        if config.llama_guard_shield:
            self.available_shields.append(ShieldType.llama_guard.value)
        if config.enable_prompt_guard:
            self.available_shields.append(ShieldType.prompt_guard.value)

    async def initialize(self) -> None:
        if self.config.enable_prompt_guard:
            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
            _ = PromptGuardShield.instance(model_dir)

    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: ShieldDef) -> None:
        raise ValueError("Registering dynamic shields is not supported")

    async def list_shields(self) -> List[ShieldDef]:
        return [
            ShieldDef(
                identifier=shield_type,
                type=shield_type,
                params={},
            )
            for shield_type in self.available_shields
        ]

    async def run_shield(
        self,
        shield_type: str,
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
        shield_def = await self.shield_store.get_shield(shield_type)
        if not shield_def:
            raise ValueError(f"Unknown shield {shield_type}")

        shield = self.get_shield_impl(shield_def)

        messages = messages.copy()
        # some shields like llama-guard require the first message to be a user message
        # since this might be a tool call, first role might not be user
        if len(messages) > 0 and messages[0].role != Role.user.value:
            messages[0] = UserMessage(content=messages[0].content)

        # TODO: we can refactor ShieldBase, etc. to be inline with the API types
        res = await shield.run(messages)
        violation = None
        if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
            violation = SafetyViolation(
                violation_level=(
                    ViolationLevel.ERROR
                    if shield.on_violation_action == OnViolationAction.RAISE
                    else ViolationLevel.WARN
                ),
                user_message=res.violation_return_message,
                metadata={
                    "violation_type": res.violation_type,
                },
            )

        return RunShieldResponse(violation=violation)

    def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
        if shield.type == ShieldType.llama_guard.value:
            cfg = self.config.llama_guard_shield
            return LlamaGuardShield(
                model=cfg.model,
                inference_api=self.inference_api,
                excluded_categories=cfg.excluded_categories,
            )
        elif shield.type == ShieldType.prompt_guard.value:
            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
            subtype = shield.params.get("prompt_guard_type", "injection")
            if subtype == "injection":
                return InjectionShield.instance(model_dir)
            elif subtype == "jailbreak":
                return JailbreakShield.instance(model_dir)
            else:
                raise ValueError(f"Unknown prompt guard type: {subtype}")
        else:
            raise ValueError(f"Unknown shield type: {shield.type}")

@@ -156,7 +156,7 @@ class ChatAgent(ShieldRunnerMixin):
        turns = await self.storage.get_session_turns(request.session_id)

        messages = []
        if len(turns) == 0 and self.agent_config.instructions != "":
        if self.agent_config.instructions != "":
            messages.append(SystemMessage(content=self.agent_config.instructions))

        for i, turn in enumerate(turns):

@@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):

        if session_info.memory_bank_id is None:
            bank_id = f"memory_bank_{session_id}"
            memory_bank = VectorMemoryBankDef(
                identifier=bank_id,
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
            await self.memory_banks_api.register_memory_bank(
                memory_bank_id=bank_id,
                params=VectorMemoryBankParams(
                    embedding_model="all-MiniLM-L6-v2",
                    chunk_size_in_tokens=512,
                ),
            )
            await self.memory_banks_api.register_memory_bank(memory_bank)
            await self.storage.add_memory_bank_to_session(session_id, bank_id)
        else:
            bank_id = session_info.memory_bank_id

@@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel
from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class MetaReferenceAgentsImplConfig(BaseModel):
    persistence_store: KVStoreConfig
    persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())

@@ -80,5 +80,5 @@ class AgentPersistence:
            except Exception as e:
                print(f"Error parsing turn: {e}")
                continue

        turns.sort(key=lambda x: (x.completed_at or datetime.min))
        return turns

@@ -32,18 +32,18 @@ class ShieldRunnerMixin:
        self.output_shields = output_shields

    async def run_multiple_shields(
        self, messages: List[Message], shield_types: List[str]
        self, messages: List[Message], identifiers: List[str]
    ) -> None:
        responses = await asyncio.gather(
            *[
                self.safety_api.run_shield(
                    shield_type=shield_type,
                    shield_id=identifier,
                    messages=messages,
                )
                for shield_type in shield_types
                for identifier in identifiers
            ]
        )
        for shield_type, response in zip(shield_types, responses):
        for identifier, response in zip(identifiers, responses):
            if not response.violation:
                continue


@@ -52,6 +52,6 @@ class ShieldRunnerMixin:
                raise SafetyException(violation)
            elif violation.violation_level == ViolationLevel.WARN:
                cprint(
                    f"[Warn]{shield_type} raised a warning",
                    f"[Warn]{identifier} raised a warning",
                    color="red",
                )

@@ -80,7 +80,7 @@ class MockInferenceAPI:

class MockSafetyAPI:
    async def run_shield(
        self, shield_type: str, messages: List[Message]
        self, shield_id: str, messages: List[Message]
    ) -> RunShieldResponse:
        return RunShieldResponse(violation=None)


@@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import *  # noqa: F403

from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin

from ..safety import ShieldRunnerMixin
from .builtin import BaseTool


@@ -4,15 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceDatasetIOConfig
from .config import LocalFSDatasetIOConfig


async def get_provider_impl(
    config: MetaReferenceDatasetIOConfig,
    config: LocalFSDatasetIOConfig,
    _deps,
):
    from .datasetio import MetaReferenceDatasetIOImpl
    from .datasetio import LocalFSDatasetIOImpl

    impl = MetaReferenceDatasetIOImpl(config)
    impl = LocalFSDatasetIOImpl(config)
    await impl.initialize()
    return impl

@@ -6,4 +6,4 @@
from llama_stack.apis.datasetio import *  # noqa: F401, F403


class MetaReferenceDatasetIOConfig(BaseModel): ...
class LocalFSDatasetIOConfig(BaseModel): ...

@@ -3,22 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
from typing import Optional

import pandas
from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.datasetio import *  # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote

from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

from .config import MetaReferenceDatasetIOConfig
from .config import LocalFSDatasetIOConfig


class BaseDataset(ABC):

@@ -40,12 +37,12 @@ class BaseDataset(ABC):

@dataclass
class DatasetInfo:
    dataset_def: DatasetDef
    dataset_def: Dataset
    dataset_impl: BaseDataset


class PandasDataframeDataset(BaseDataset):
    def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None:
    def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dataset_def = dataset_def
        self.df = None

@@ -73,37 +70,15 @@ class PandasDataframeDataset(BaseDataset):
        if self.df is not None:
            return

        # TODO: more robust support w/ data url
        if self.dataset_def.url.uri.endswith(".csv"):
            df = pandas.read_csv(self.dataset_def.url.uri)
        elif self.dataset_def.url.uri.endswith(".xlsx"):
            df = pandas.read_excel(self.dataset_def.url.uri)
        elif self.dataset_def.url.uri.startswith("data:"):
            parts = parse_data_url(self.dataset_def.url.uri)
            data = parts["data"]
            if parts["is_base64"]:
                data = base64.b64decode(data)
            else:
                data = unquote(data)
                encoding = parts["encoding"] or "utf-8"
                data = data.encode(encoding)

            mime_type = parts["mimetype"]
            mime_category = mime_type.split("/")[0]
            data_bytes = io.BytesIO(data)

            if mime_category == "text":
                df = pandas.read_csv(data_bytes)
            else:
                df = pandas.read_excel(data_bytes)
        else:
            raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
        df = get_dataframe_from_url(self.dataset_def.url)
        if df is None:
            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")

        self.df = self._validate_dataset_schema(df)


class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

@@ -114,17 +89,14 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):

    async def register_dataset(
        self,
        dataset_def: DatasetDef,
        dataset: Dataset,
    ) -> None:
        dataset_impl = PandasDataframeDataset(dataset_def)
        self.dataset_infos[dataset_def.identifier] = DatasetInfo(
            dataset_def=dataset_def,
        dataset_impl = PandasDataframeDataset(dataset)
        self.dataset_infos[dataset.identifier] = DatasetInfo(
            dataset_def=dataset,
            dataset_impl=dataset_impl,
        )

    async def list_datasets(self) -> List[DatasetDef]:
        return [i.dataset_def for i in self.dataset_infos.values()]

    async def get_rows_paginated(
        self,
        dataset_id: str,

llama_stack/providers/inline/eval/meta_reference/config.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)
from pydantic import BaseModel


class MetaReferenceEvalConfig(BaseModel):
    kvstore: KVStoreConfig = SqliteKVStoreConfig(
        db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
    )  # Uses SQLite config specific to Meta Reference Eval storage

@@ -6,16 +6,22 @@
from enum import Enum
from llama_models.llama3.api.datatypes import *  # noqa: F403

from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm

from .config import MetaReferenceEvalConfig

EVAL_TASKS_PREFIX = "eval_tasks:"


class ColumnName(Enum):
    input_query = "input_query"

@@ -25,7 +31,7 @@ class ColumnName(Enum):
    generated_answer = "generated_answer"


class MetaReferenceEvalImpl(Eval):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
    def __init__(
        self,
        config: MetaReferenceEvalConfig,

@@ -43,12 +49,32 @@ class MetaReferenceEvalImpl(Eval):
        # TODO: assume sync job, will need jobs API for async scheduling
        self.jobs = {}

    async def initialize(self) -> None: ...
        self.eval_tasks = {}

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing eval_tasks from kvstore
        start_key = EVAL_TASKS_PREFIX
        end_key = f"{EVAL_TASKS_PREFIX}\xff"
        stored_eval_tasks = await self.kvstore.range(start_key, end_key)

        for eval_task in stored_eval_tasks:
            eval_task = EvalTask.model_validate_json(eval_task)
            self.eval_tasks[eval_task.identifier] = eval_task

    async def shutdown(self) -> None: ...

    async def register_eval_task(self, task_def: EvalTask) -> None:
        # Store in kvstore
        key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
        await self.kvstore.set(
            key=key,
            value=task_def.json(),
        )
        self.eval_tasks[task_def.identifier] = task_def

    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")


@@ -70,21 +96,28 @@ class MetaReferenceEvalImpl(Eval):
            f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
        )

    async def evaluate_batch(
    async def run_eval(
        self,
        dataset_id: str,
        candidate: EvalCandidate,
        scoring_functions: List[str],
        task_id: str,
        task_config: EvalTaskConfig,
    ) -> Job:
        task_def = self.eval_tasks[task_id]
        dataset_id = task_def.dataset_id
        candidate = task_config.eval_candidate
        scoring_functions = task_def.scoring_functions

        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
            rows_in_page=(
                -1 if task_config.num_examples is None else task_config.num_examples
            ),
        )
        res = await self.evaluate(
        res = await self.evaluate_rows(
            task_id=task_id,
            input_rows=all_rows.rows,
            candidate=candidate,
            scoring_functions=scoring_functions,
            task_config=task_config,
        )

        # TODO: currently needs to wait for generation before returning

@@ -93,12 +126,14 @@ class MetaReferenceEvalImpl(Eval):
        self.jobs[job_id] = res
        return Job(job_id=job_id)

    async def evaluate(
    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        candidate: EvalCandidate,
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse:
        candidate = task_config.eval_candidate
        if candidate.type == "agent":
            raise NotImplementedError(
                "Evaluation with generation has not been implemented for agents"

@@ -108,7 +143,7 @@ class MetaReferenceEvalImpl(Eval):
        ), "SamplingParams.max_tokens must be provided"

        generations = []
        for x in input_rows:
        for x in tqdm(input_rows):
            if ColumnName.completion_input.value in x:
                input_content = eval(str(x[ColumnName.completion_input.value]))
                response = await self.inference_api.completion(

@@ -122,14 +157,17 @@ class MetaReferenceEvalImpl(Eval):
                    }
                )
            elif ColumnName.chat_completion_input.value in x:
                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
                chat_completion_input_str = str(
                    x[ColumnName.chat_completion_input.value]
                )
                input_messages = eval(chat_completion_input_str)
                input_messages = [UserMessage(**x) for x in input_messages]
                messages = []
                if candidate.system_message:
                    messages.append(candidate.system_message)
                messages += input_messages
                response = await self.inference_api.chat_completion(
                    model=candidate.model,
                    model_id=candidate.model,
                    messages=messages,
                    sampling_params=candidate.sampling_params,
                )

@@ -147,23 +185,33 @@ class MetaReferenceEvalImpl(Eval):
            for input_r, generated_r in zip(input_rows, generations)
        ]

        if task_config.type == "app" and task_config.scoring_params is not None:
            scoring_functions_dict = {
                scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
                for scoring_fn_id in scoring_functions
            }
        else:
            scoring_functions_dict = {
                scoring_fn_id: None for scoring_fn_id in scoring_functions
            }

        score_response = await self.scoring_api.score(
            input_rows=score_input_rows, scoring_functions=scoring_functions
            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
        )

        return EvaluateResponse(generations=generations, scores=score_response.results)

    async def job_status(self, job_id: str) -> Optional[JobStatus]:
    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
        if job_id in self.jobs:
            return JobStatus.completed

        return None

    async def job_cancel(self, job_id: str) -> None:
    async def job_cancel(self, task_id: str, job_id: str) -> None:
        raise NotImplementedError("Job cancel is not implemented yet")

    async def job_result(self, job_id: str) -> EvaluateResponse:
        status = await self.job_status(job_id)
    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
        status = await self.job_status(task_id, job_id)
        if not status or status != JobStatus.completed:
            raise ValueError(f"Job is not completed, Status: {status.value}")


@@ -86,6 +86,7 @@ class Llama:
        and loads the pre-trained model and tokenizer.
        """
        model = resolve_model(config.model)
        llama_model = model.core_model_id.value

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")

@@ -186,13 +187,20 @@ class Llama:
        model.load_state_dict(state_dict, strict=False)

        print(f"Loaded in {time.time() - start_time:.2f} seconds")
        return Llama(model, tokenizer, model_args)
        return Llama(model, tokenizer, model_args, llama_model)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
    def __init__(
        self,
        model: Transformer,
        tokenizer: Tokenizer,
        args: ModelArgs,
        llama_model: str,
    ):
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)
        self.llama_model = llama_model

    @torch.inference_mode()
    def generate(

@@ -369,7 +377,7 @@ class Llama:
        self,
        request: ChatCompletionRequest,
    ) -> Generator:
        messages = chat_completion_request_to_messages(request)
        messages = chat_completion_request_to_messages(request, self.llama_model)

        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens

@@ -11,8 +11,15 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_media_to_url,
    request_has_media,
)

from .config import MetaReferenceInferenceConfig
from .generation import Llama

@@ -23,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)


class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        self.config = config
        model = resolve_model(config.model)
        ModelRegistryHelper.__init__(
            self,
            [
                build_model_alias(
                    model.descriptor(),
                    model.core_model_id.value,
                )
            ],
        )
        if model is None:
            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
        self.model = model

@@ -40,17 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
        else:
            self.generator = Llama.build(self.config)

    async def register_model(self, model: ModelDef) -> None:
        raise ValueError("Dynamic model registration is not supported")

    async def list_models(self) -> List[ModelDef]:
        return [
            ModelDef(
                identifier=self.model.descriptor(),
                llama_model=self.model.descriptor(),
            )
        ]

    async def shutdown(self) -> None:
        if self.config.create_distributed_process_group:
            self.generator.stop()

@@ -66,9 +71,12 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                f"Model mismatch: {request.model} != {self.model.descriptor()}"
            )

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def completion(
        self,
        model: str,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,

@@ -79,7 +87,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        request = CompletionRequest(
            model=model,
            model=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,

@@ -87,6 +95,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await request_with_localized_media(request)

        if request.stream:
            return self._stream_completion(request)

@@ -185,7 +194,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

    async def chat_completion(
        self,
        model: str,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,

@@ -200,7 +209,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model,
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],

@@ -211,6 +220,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await request_with_localized_media(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():

@@ -384,7 +394,35 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

    async def embeddings(
        self,
        model: str,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()


async def request_with_localized_media(
    request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
    if not request_has_media(request):
        return request

    async def _convert_single_content(content):
        if isinstance(content, ImageMedia):
            url = await convert_image_media_to_url(content, download=True)
            return ImageMedia(image=URL(uri=url))
        else:
            return content

    async def _convert_content(content):
        if isinstance(content, list):
            return [await _convert_single_content(c) for c in content]
        else:
            return await _convert_single_content(content)

    if isinstance(request, ChatCompletionRequest):
        for m in request.messages:
            m.content = await _convert_content(m.content)
    else:
        request.content = await _convert_content(request.content)

    return request

@@ -20,6 +20,7 @@ from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model

from termcolor import cprint
from torch import nn, Tensor


@@ -27,9 +28,7 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

from llama_stack.apis.inference import QuantizationType

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceQuantizedInferenceConfig,
)
from ..config import MetaReferenceQuantizedInferenceConfig


def swiglu_wrapper(

Some files were not shown because too many files have changed in this diff.