Merge branch 'main' into clarifai-inference-provider

This commit is contained in:
sanjaychelliah 2024-11-26 18:01:45 +05:30 committed by GitHub
commit 4b9085d312
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
536 changed files with 34661 additions and 12116 deletions

View file

@ -6,7 +6,17 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import (
Any,
AsyncIterator,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
@ -44,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
tavily = "tavily"
@json_schema_type
@ -396,6 +407,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""streamed agent turn completion response."""
event: AgentTurnResponseEvent
@ -404,6 +417,7 @@ class AgentStepResponse(BaseModel):
step: Step
@runtime_checkable
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(
@ -424,18 +438,16 @@ class Agents(Protocol):
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AgentTurnResponseStreamChunk: ...
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")
async def get_agents_turn(
self,
agent_id: str,
turn_id: str,
self, agent_id: str, session_id: str, turn_id: str
) -> Turn: ...
@webmethod(route="/agents/step/get")
async def get_agents_step(
self, agent_id: str, turn_id: str, step_id: str
self, agent_id: str, session_id: str, turn_id: str, step_id: str
) -> AgentStepResponse: ...
@webmethod(route="/agents/session/create")

View file

@ -7,22 +7,26 @@
import asyncio
import json
import os
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
import fire
import httpx
from dotenv import load_dotenv
from pydantic import BaseModel
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .agents import * # noqa: F403
import logging
from .event_logger import EventLogger
log = logging.getLogger(__name__)
load_dotenv()
@ -70,6 +74,14 @@ class AgentsClient(Agents):
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return await self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
@ -85,13 +97,15 @@ class AgentsClient(Agents):
try:
jdata = json.loads(data)
if "error" in jdata:
cprint(data, "red")
log.error(data)
continue
yield AgentTurnResponseStreamChunk(**jdata)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
log.error(f"Error with parsing or validation: {e}")
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
async def _run_agent(
@ -114,8 +128,8 @@ async def _run_agent(
)
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
log.info(f"User> {content}", color="white", attrs=["bold"])
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
@ -127,13 +141,12 @@ async def _run_agent(
)
)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
async for event, logger in EventLogger().log(iterator):
if logger is not None:
log.info(logger)
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -173,8 +186,7 @@ async def run_llama_3_1(host: str, port: int):
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -215,8 +227,7 @@ async def run_llama_3_2_rag(host: str, port: int):
)
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
@ -262,7 +273,7 @@ async def run_llama_3_2(host: str, port: int):
)
def main(host: str, port: int, run_type: str):
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
@ -274,7 +285,10 @@ def main(host: str, port: int, run_type: str):
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
args = [host, port]
if model is not None:
args.append(model)
asyncio.run(fn[run_type](*args))
if __name__ == "__main__":

View file

@ -180,5 +180,5 @@ class EventLogger:
color="cyan",
)
preivous_event_type = event_type
previous_event_type = event_type
previous_step_type = step_type

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@ -47,8 +47,9 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion")
@webmethod(route="/batch-inference/completion")
async def batch_completion(
self,
model: str,
@ -57,7 +58,7 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion")
@webmethod(route="/batch-inference/chat-completion")
async def batch_chat_completion(
self,
model: str,

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"

View file

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class StringType(BaseModel):
type: Literal["string"] = "string"
class NumberType(BaseModel):
type: Literal["number"] = "number"
class BooleanType(BaseModel):
type: Literal["boolean"] = "boolean"
class ArrayType(BaseModel):
type: Literal["array"] = "array"
class ObjectType(BaseModel):
type: Literal["object"] = "object"
class JsonType(BaseModel):
type: Literal["json"] = "json"
class UnionType(BaseModel):
type: Literal["union"] = "union"
class ChatCompletionInputType(BaseModel):
# expects List[Message] for messages
type: Literal["chat_completion_input"] = "chat_completion_input"
class CompletionInputType(BaseModel):
# expects InterleavedTextMedia for content
type: Literal["completion_input"] = "completion_input"
class AgentTurnInputType(BaseModel):
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
ParamType = Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
]
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str

View file

@ -1,63 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional, Protocol
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class TrainEvalDatasetColumnType(Enum):
dialog = "dialog"
text = "text"
media = "media"
number = "number"
json = "json"
@json_schema_type
class TrainEvalDataset(BaseModel):
"""Dataset to be used for training or evaluating language models."""
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
columns: Dict[str, TrainEvalDatasetColumnType]
content_url: URL
metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class CreateDatasetRequest(BaseModel):
"""Request to create a dataset."""
uuid: str
dataset: TrainEvalDataset
class Datasets(Protocol):
@webmethod(route="/datasets/create")
def create_dataset(
self,
uuid: str,
dataset: TrainEvalDataset,
) -> None: ...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_uuid: str,
) -> TrainEvalDataset: ...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_uuid: str,
) -> None: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .reward_scoring import * # noqa: F401 F403
from .datasetio import * # noqa: F401 F403

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetIOClient(DatasetIO):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasetio/get_rows_paginated",
params={
"dataset_id": dataset_id,
"rows_in_page": rows_in_page,
"page_token": page_token,
"filter_condition": filter_condition,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return PaginatedRowsResult(**response.json())
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
# datsetio client to get the rows
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
response = await datasetio_client.get_rows_paginated(
dataset_id="test-dataset",
rows_in_page=4,
page_token=None,
filter_condition=None,
)
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.datasets import * # noqa: F403
@json_schema_type
class PaginatedRowsResult(BaseModel):
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
next_page_token: Optional[str] = None
class DatasetStore(Protocol):
def get_dataset(self, dataset_id: str) -> Dataset: ...
@runtime_checkable
class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/get-rows-paginated", method="GET")
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import * # noqa: F401 F403

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from .datasets import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetsClient(Datasets):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_dataset(
self,
dataset_def: DatasetDefWithProvider,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/datasets/register",
json={
"dataset_def": json.loads(dataset_def.json()),
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
return
async def get_dataset(
self,
dataset_identifier: str,
) -> Optional[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/get",
params={
"dataset_identifier": dataset_identifier,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return DatasetDefWithProvider(**response.json())
async def list_datasets(self) -> List[DatasetDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/datasets/list",
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return [DatasetDefWithProvider(**x) for x in response.json()]
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
class CommonDatasetFields(BaseModel):
dataset_schema: Dict[str, ParamType]
url: URL
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str:
return self.provider_resource_id
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
provider_dataset_id: Optional[str] = None
class Datasets(Protocol):
@webmethod(route="/datasets/register", method="POST")
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/get", method="GET")
async def get_dataset(
self,
dataset_id: str,
) -> Optional[Dataset]: ...
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .evals import * # noqa: F401 F403
from .eval import * # noqa: F401 F403

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Optional, Protocol, Union
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
@json_schema_type
class ModelCandidate(BaseModel):
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[
Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
]
@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
@json_schema_type
class AppEvalTaskConfig(BaseModel):
type: Literal["app"] = "app"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optinally add any specific dataset config here
EvalTaskConfig = Annotated[
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
]
@json_schema_type
class EvaluateResponse(BaseModel):
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
@webmethod(route="/eval/run-eval", method="POST")
async def run_eval(
self,
task_id: str,
task_config: EvalTaskConfig,
) -> Job: ...
@webmethod(route="/eval/evaluate-rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval_tasks import * # noqa: F401 F403

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
class CommonEvalTaskFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
@property
def eval_task_id(self) -> str:
return self.identifier
@property
def provider_eval_task_id(self) -> str:
return self.provider_resource_id
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None
provider_eval_task_id: Optional[str] = None
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval-tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTask]: ...
@webmethod(route="/eval-tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/register", method="POST")
async def register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_eval_task_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...

View file

@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Protocol
from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
class TextGenerationMetric(Enum):
perplexity = "perplexity"
rouge = "rouge"
bleu = "bleu"
class QuestionAnsweringMetric(Enum):
em = "em"
f1 = "f1"
class SummarizationMetric(Enum):
rouge = "rouge"
bleu = "bleu"
class EvaluationJob(BaseModel):
job_uuid: str
class EvaluationJobLogStream(BaseModel):
job_uuid: str
class EvaluateTaskRequestCommon(BaseModel):
job_uuid: str
dataset: TrainEvalDataset
checkpoint: Checkpoint
# generation params
sampling_params: SamplingParams = SamplingParams()
@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate text generation."""
metrics: List[TextGenerationMetric]
@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
"""Request to evaluate question answering."""
metrics: List[QuestionAnsweringMetric]
@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate summarization."""
metrics: List[SummarizationMetric]
class EvaluationJobStatusResponse(BaseModel):
job_uuid: str
@json_schema_type
class EvaluationJobArtifactsResponse(BaseModel):
"""Artifacts of a evaluation job."""
job_uuid: str
class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/")
def evaluate_text_generation(
self,
metrics: List[TextGenerationMetric],
) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/")
def evaluate_question_answering(
self,
metrics: List[QuestionAnsweringMetric],
) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/")
def evaluate_summarization(
self,
metrics: List[SummarizationMetric],
) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/status")
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
@webmethod(route="/evaluate/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
@webmethod(route="/evaluate/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse: ...

View file

@ -53,6 +53,7 @@ class InferenceClient(Inference):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -63,9 +64,33 @@ class InferenceClient(Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
return ChatCompletionResponse(**j)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
@ -77,7 +102,8 @@ class InferenceClient(Inference):
if response.status_code != 200:
content = await response.aread()
cprint(
f"Error: HTTP {response.status_code} {content.decode()}", "red"
f"Error: HTTP {response.status_code} {content.decode()}",
"red",
)
return
@ -85,16 +111,11 @@ class InferenceClient(Inference):
if line.startswith("data:"):
data = line[len("data: ") :]
try:
if request.stream:
if "error" in data:
cprint(data, "red")
continue
if "error" in data:
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
else:
yield ChatCompletionResponse(**json.loads(data))
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
@ -120,7 +141,8 @@ async def run_main(
else:
logprobs_config = None
iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
@ -150,7 +172,7 @@ async def run_mm_main(
],
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,

View file

@ -6,7 +6,15 @@
from enum import Enum
from typing import List, Literal, Optional, Protocol, Union
from typing import (
AsyncIterator,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
@ -14,6 +22,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel):
@ -24,6 +33,7 @@ class LogProbConfig(BaseModel):
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
int4 = "int4"
@json_schema_type
@ -36,8 +46,14 @@ class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
@json_schema_type
class Int4QuantizationConfig(BaseModel):
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
Field(discriminator="type"),
]
@ -73,11 +89,35 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: Optional[StopReason] = None
class ResponseFormatType(Enum):
json_schema = "json_schema"
grammar = "grammar"
class JsonSchemaResponseFormat(BaseModel):
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
json_schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -87,7 +127,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""
completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
@ -105,6 +146,7 @@ class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@ -112,7 +154,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]
@json_schema_type
@ -127,6 +169,7 @@ class ChatCompletionRequest(BaseModel):
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -164,7 +207,7 @@ class BatchChatCompletionRequest(BaseModel):
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]
@json_schema_type
@ -172,34 +215,45 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
@runtime_checkable
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat_completion")
@webmethod(route="/inference/chat-completion")
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ...
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from typing import Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@ -12,15 +12,15 @@ from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
description: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
method: str
providers: List[str]
provider_types: List[str]
@json_schema_type
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
async def list_providers(self) -> Dict[str, ProviderInfo]: ...

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
@ -13,11 +12,11 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks.client import MemoryBanksClient
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@ -35,45 +34,6 @@ class MemoryClient(Memory):
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory/create",
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents(
self,
bank_id: str,
@ -113,23 +73,28 @@ class MemoryClient(Memory):
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
banks_client = MemoryBanksClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
bank = VectorMemoryBank(
identifier="test_bank",
provider_id="",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await banks_client.register_memory_bank(
bank.identifier,
VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
provider_resource_id=bank.identifier,
)
cprint(json.dumps(bank.dict(), indent=4), "green")
retrieved_bank = await client.get_memory_bank(bank.bank_id)
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
urls = [
"memory_optimizations.rst",
@ -160,15 +125,17 @@ async def run_main(host: str, port: int, stream: bool):
for i, path in enumerate(files)
]
client = MemoryClient(f"http://{host}:{port}")
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"How do I use Lora?",
],
@ -178,7 +145,7 @@ async def run_main(host: str, port: int, stream: bool):
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
bank_id=bank.identifier,
query=[
"Tell me more about llama3 and torchtune",
],

View file

@ -8,14 +8,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
@json_schema_type
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@ -76,45 +38,13 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@runtime_checkable
class Memory(Protocol):
@webmethod(route="/memory/create")
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
memory_bank_store: MemoryBankStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@ -126,13 +56,6 @@ class Memory(Protocol):
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory/query")
async def query_documents(
self,
@ -140,17 +63,3 @@ class Memory(Protocol):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...

View file

@ -6,7 +6,7 @@
import asyncio
from typing import List, Optional
from typing import Any, Dict, List, Optional
import fire
import httpx
@ -15,6 +15,27 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(
j: Optional[Dict[str, Any]]
) -> MemoryBankDefWithProvider:
if j is None:
return None
if "type" not in j:
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBank(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBank(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBank(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBank(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
@ -25,37 +46,71 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
async def list_memory_banks(self) -> List[MemoryBank]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()]
return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]:
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_resource_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/memory_banks/register",
json={
"memory_bank_id": memory_bank_id,
"provider_resource_id": provider_resource_id,
"provider_id": provider_id,
"params": params.dict(),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def get_memory_bank(
self,
memory_bank_id: str,
) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_type": bank_type.value,
"memory_bank_id": memory_bank_id,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return MemoryBankSpec(**j)
return deserialize_memory_bank_def(j)
async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_available_memory_banks()
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
# register memory bank for the first time
response = await client.register_memory_bank(
memory_bank_id="test_bank2",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
cprint(f"register_memory_bank response={response}", "blue")
# list again after registering
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")

View file

@ -4,29 +4,146 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from enum import Enum
from typing import (
Annotated,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
# define params for each type of memory bank, this leads to a tagged union
# accepted as input from the API or from the config.
@json_schema_type
class VectorMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
@json_schema_type
class KeywordMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
@webmethod(route="/memory_banks/get", method="GET")
async def get_serving_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...
@json_schema_type
class GraphMemoryBankParams(BaseModel):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
BankParams = Annotated[
Union[
VectorMemoryBankParams,
KeyValueMemoryBankParams,
KeywordMemoryBankParams,
GraphMemoryBankParams,
],
Field(discriminator="memory_bank_type"),
]
# Some common functionality for memory banks.
class MemoryBankResourceMixin(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
@property
def memory_bank_id(self) -> str:
return self.identifier
@property
def provider_memory_bank_id(self) -> str:
return self.provider_resource_id
@json_schema_type
class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
@json_schema_type
class KeywordMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
@json_schema_type
class GraphMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBank = Annotated[
Union[
VectorMemoryBank,
KeyValueMemoryBank,
KeywordMemoryBank,
GraphMemoryBank,
],
Field(discriminator="memory_bank_type"),
]
class MemoryBankInput(BaseModel):
memory_bank_id: str
params: BankParams
provider_memory_bank_id: Optional[str] = None
@runtime_checkable
class MemoryBanks(Protocol):
@webmethod(route="/memory-banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory-banks/get", method="GET")
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory-banks/register", method="POST")
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory-banks/unregister", method="POST")
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import json
from typing import List, Optional
@ -25,21 +26,32 @@ class ModelsClient(Models):
async def shutdown(self) -> None:
pass
async def list_models(self) -> List[ModelServingSpec]:
async def list_models(self) -> List[Model]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/models/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [ModelServingSpec(**x) for x in response.json()]
return [Model(**x) for x in response.json()]
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
async def register_model(self, model: Model) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/models/register",
json={
"model": json.loads(model.model_dump_json()),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def get_model(self, identifier: str) -> Optional[Model]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/models/get",
params={
"core_model_id": core_model_id,
"identifier": identifier,
},
headers={"Content-Type": "application/json"},
)
@ -47,7 +59,16 @@ class ModelsClient(Models):
j = response.json()
if j is None:
return None
return ModelServingSpec(**j)
return Model(**j)
async def unregister_model(self, model_id: str) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/models/delete",
params={"model_id": model_id},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def run_main(host: str, port: int, stream: bool):

View file

@ -4,29 +4,60 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
from llama_stack.apis.resource import Resource, ResourceType
class CommonModelFields(BaseModel):
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
@json_schema_type
class ModelServingSpec(BaseModel):
llama_model: Model = Field(
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
)
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
)
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@runtime_checkable
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> List[ModelServingSpec]: ...
async def list_models(self) -> List[Model]: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
async def get_model(self, identifier: str) -> Optional[Model]: ...
@webmethod(route="/models/register", method="POST")
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model: ...
@webmethod(route="/models/unregister", method="POST")
async def unregister_model(self, model_id: str) -> None: ...

View file

@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
job_uuid: str
model: str
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: Union[
@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
finetuned_model: URL
dataset: TrainEvalDataset
validation_dataset: TrainEvalDataset
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
@ -176,13 +176,13 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune")
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
self,
job_uuid: str,
model: str,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@ -193,13 +193,13 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize")
@webmethod(route="/post-training/preference-optimize")
def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
@ -208,22 +208,22 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs")
@webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post_training/job/logs")
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status")
@webmethod(route="/post-training/job/status")
def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel")
@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts")
@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class ResourceType(Enum):
model = "model"
shield = "shield"
memory_bank = "memory_bank"
dataset = "dataset"
scoring_function = "scoring_function"
eval_task = "eval_task"
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
default=None,
)
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)"
)

View file

@ -1,55 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
message: Message
score: float
@json_schema_type
class DialogGenerations(BaseModel):
dialog: List[Message]
sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
dialog: List[Message]
scored_generations: List[ScoredMessage]
@json_schema_type
class RewardScoringRequest(BaseModel):
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
dialog_generations: List[DialogGenerations]
model: str
@json_schema_type
class RewardScoringResponse(BaseModel):
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
scored_generations: List[ScoredDialogGenerations]
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
def reward_score(
self,
dialog_generations: List[DialogGenerations],
model: str,
) -> Union[RewardScoringResponse]: ...

View file

@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
def encodable_dict(d: BaseModel):
return json.loads(d.json())
return json.loads(d.model_dump_json())
class SafetyClient(Safety):
@ -41,13 +41,13 @@ class SafetyClient(Safety):
pass
async def run_shield(
self, shield_type: str, messages: List[Message]
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/safety/run_shield",
json=dict(
shield_type=shield_type,
shield_id=shield_id,
messages=[encodable_dict(m) for m in messages],
),
headers={
@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
)
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_type="llama_guard",
shield_id="Llama-Guard-3-1B",
messages=[message],
)
print(response)
@ -91,13 +91,7 @@ async def run_main(host: str, port: int, image_path: str = None):
]:
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_type="llama_guard",
messages=[message],
)
print(response)
response = await client.run_shield(
shield_type="injection_shield",
shield_id="llama_guard",
messages=[message],
)
print(response)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Protocol
from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@json_schema_type
@ -37,8 +38,18 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
class Safety(Protocol):
@webmethod(route="/safety/run_shield")
shield_store: ShieldStore
@webmethod(route="/safety/run-shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dataset import * # noqa: F401 F403
from .scoring import * # noqa: F401 F403

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio.client import DatasetIOClient
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class ScoringClient(Scoring):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/scoring/score_batch",
json={
"dataset_id": dataset_id,
"scoring_functions": scoring_functions,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return ScoreBatchResponse(**response.json())
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/scoring/score",
json={
"input_rows": input_rows,
"scoring_functions": scoring_functions,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
response.raise_for_status()
if not response.json():
return
return ScoreResponse(**response.json())
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
# register dataset
test_file = (
Path(os.path.abspath(__file__)).parent.parent.parent
/ "providers/tests/datasetio/test_dataset.csv"
)
test_url = data_url_from_file(str(test_file))
response = await client.register_dataset(
DatasetDefWithProvider(
identifier="test-dataset",
provider_id="meta0",
url=URL(
uri=test_url,
),
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
)
# list datasets
list_dataset = await client.list_datasets()
cprint(list_dataset, "blue")
# datsetio client to get the rows
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
response = await datasetio_client.get_rows_paginated(
dataset_id="test-dataset",
rows_in_page=4,
page_token=None,
filter_condition=None,
)
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
# scoring client to score the rows
scoring_client = ScoringClient(f"http://{host}:{port}")
response = await scoring_client.score(
input_rows=response.rows,
scoring_functions=["equality"],
)
cprint(f"score response={response}", "blue")
# test scoring batch using datasetio api
scoring_client = ScoringClient(f"http://{host}:{port}")
response = await scoring_client.score_batch(
dataset_id="test-dataset",
scoring_functions=["equality"],
)
cprint(f"score_batch response={response}", "cyan")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch")
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import * # noqa: F401 F403

View file

@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
ScoringFnParamsType.llm_as_judge.value
)
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = (
ScoringFnParamsType.regex_parser.value
)
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
],
Field(discriminator="type"),
]
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring-functions/get", method="GET")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions/register", method="POST")
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ...

View file

@ -25,21 +25,41 @@ class ShieldsClient(Shields):
async def shutdown(self) -> None:
pass
async def list_shields(self) -> List[ShieldSpec]:
async def list_shields(self) -> List[Shield]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [ShieldSpec(**x) for x in response.json()]
return [Shield(**x) for x in response.json()]
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str],
provider_id: Optional[str],
params: Optional[Dict[str, Any]],
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/shields/register",
json={
"shield_id": shield_id,
"provider_shield_id": provider_shield_id,
"provider_id": provider_id,
"params": params,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
params={
"shield_type": shield_type,
"shield_id": shield_id,
},
headers={"Content-Type": "application/json"},
)
@ -49,7 +69,7 @@ class ShieldsClient(Shields):
if j is None:
return None
return ShieldSpec(**j)
return Shield(**j)
async def run_main(host: str, port: int, stream: bool):

View file

@ -4,25 +4,52 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from pydantic import BaseModel
from llama_stack.distribution.datatypes import GenericProviderConfig
from llama_stack.apis.resource import Resource, ResourceType
class CommonShieldFields(BaseModel):
params: Optional[Dict[str, Any]] = None
@json_schema_type
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_type, and corresponding config. ",
)
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
return self.provider_resource_id
class ShieldInput(CommonShieldFields):
shield_id: str
provider_id: Optional[str] = None
provider_shield_id: Optional[str] = None
@runtime_checkable
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldSpec]: ...
async def list_shields(self) -> List[Shield]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield: ...

View file

@ -13,7 +13,6 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
class FilteringFunction(Enum):
@ -40,12 +39,12 @@ class SyntheticDataGenerationRequest(BaseModel):
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[ScoredDialogGenerations]
synthetic_data: List[Dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
@webmethod(route="/synthetic-data-generation/generate")
def synthetic_data_generate(
self,
dialogs: List[Message],

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, Union
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@ -123,9 +123,10 @@ Event = Annotated[
]
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event")
@webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get_trace", method="GET")
@webmethod(route="/telemetry/get-trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_API_VERSION = "alpha"

View file

@ -9,15 +9,27 @@ import asyncio
import json
import os
import shutil
import time
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional
import httpx
from pydantic import BaseModel
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
)
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument(
"--ignore-patterns",
type=str,
@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
@dataclass
class DownloadTask:
url: str
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: Optional[int] = None
retries: int = 0
max_retries: int = 3
class DownloadError(Exception):
pass
class CustomTransferSpeedColumn(TransferSpeedColumn):
def render(self, task):
if task.finished:
return "-"
return super().render(task)
class ParallelDownloader:
def __init__(
self,
max_concurrent_downloads: int = 3,
buffer_size: int = 1024 * 1024,
timeout: int = 30,
):
self.max_concurrent_downloads = max_concurrent_downloads
self.buffer_size = buffer_size
self.timeout = timeout
self.console = Console()
self.progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=40),
"[progress.percentage]{task.percentage:>3.1f}%",
DownloadColumn(),
CustomTransferSpeedColumn(),
TimeRemainingColumn(),
console=self.console,
expand=True,
)
self.client_options = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
}
async def retry_with_exponential_backoff(
self, task: DownloadTask, func, *args, **kwargs
):
last_exception = None
for attempt in range(task.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < task.max_retries - 1:
wait_time = min(30, 2**attempt) # Cap at 30 seconds
self.console.print(
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
)
await asyncio.sleep(wait_time)
continue
raise last_exception
async def get_file_info(
self, client: httpx.AsyncClient, task: DownloadTask
) -> None:
async def _get_info():
response = await client.head(
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
)
response.raise_for_status()
return response
try:
response = await self.retry_with_exponential_backoff(task, _get_info)
task.url = str(response.url)
task.total_size = int(response.headers.get("Content-Length", 0))
if task.total_size == 0:
raise DownloadError(
f"Unable to determine file size for {task.output_file}. "
"The server might not support range requests."
)
# Update the progress bar's total size once we know it
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
except httpx.HTTPError as e:
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
raise
def verify_file_integrity(self, task: DownloadTask) -> bool:
if not os.path.exists(task.output_file):
return False
return os.path.getsize(task.output_file) == task.total_size
async def download_chunk(
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream(
"GET", task.url, headers=headers, **self.client_options
) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
file.seek(start)
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
try:
await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e:
raise DownloadError(
f"Failed to download chunk {start}-{end} after "
f"{task.max_retries} attempts: {str(e)}"
) from e
async def prepare_download(self, task: DownloadTask) -> None:
output_dir = os.path.dirname(task.output_file)
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(task.output_file):
task.downloaded_size = os.path.getsize(task.output_file)
async def download_file(self, task: DownloadTask) -> None:
try:
async with httpx.AsyncClient(**self.client_options) as client:
await self.get_file_info(client, task)
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(
f"[green]Already downloaded {task.output_file}[/green]"
)
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
try:
# Split the remaining download into chunks
chunk_size = 27_000_000_000 # Cloudfront max chunk size
chunks = []
current_pos = task.downloaded_size
while current_pos < task.total_size:
chunk_end = min(
current_pos + chunk_size - 1, task.total_size - 1
)
chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1
# Download chunks in sequence
for chunk_start, chunk_end in chunks:
await self.download_chunk(client, task, chunk_start, chunk_end)
except Exception as e:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(
task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
)
raise DownloadError(
f"Download failed for {task.output_file}: {str(e)}"
) from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
try:
total_remaining_size = sum(
task.total_size - task.downloaded_size for task in tasks
)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free
# Add 10% buffer for safety
required_space = int(total_remaining_size * 1.1)
if free_space < required_space:
self.console.print(
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
f"Available: {free_space // (1024 * 1024)} MB[/red]"
)
return False
return True
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: List[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")
if not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task(
desc, total=task.total_size, completed=task.downloaded_size
)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
async def download_with_semaphore(task: DownloadTask):
async with semaphore:
try:
await self.download_file(task)
except Exception as e:
failed_tasks.append((task, str(e)))
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
self.console.print(
f"[red]- {Path(task.output_file).name}: {error}[/red]"
)
raise DownloadError(f"{len(failed_tasks)} downloads failed")
def _hf_download(
model: "Model",
hf_token: str,
@ -120,67 +378,50 @@ def _hf_download(
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
# I believe we can use some concurrency here if needed but not sure it is worth it
# Create download tasks for each file
tasks = []
for f in info.files:
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
cprint(f"Downloading `{f}`...", "white")
downloader = ResumableDownloader(url, output_file, total_size)
asyncio.run(downloader.download())
print(f"\nSuccessfully downloaded model to {output_dir}")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
if args.manifest_file:
_download_from_manifest(args.manifest_file)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
prompt_guard = prompt_guard_model_sku()
if args.model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(args.model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
meta_url = input(
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
tasks.append(
DownloadTask(
url=url, output_file=output_file, total_size=total_size, max_retries=3
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url, info)
)
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
"white",
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
)
class ModelEntry(BaseModel):
model_id: str
files: Dict[str, str]
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class Manifest(BaseModel):
@ -188,7 +429,7 @@ class Manifest(BaseModel):
expires_on: datetime
def _download_from_manifest(manifest_file: str):
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f:
@ -198,143 +439,88 @@ def _download_from_manifest(manifest_file: str):
if datetime.now() > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
for entry in manifest.models:
print(f"Downloading model {entry.model_id}...")
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
cprint(f"Output directory {output_dir} is not empty.", "red")
console.print(
f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
)
while True:
resp = input(
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
)
if resp.lower() == "restart" or resp.lower() == "r":
if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
break
elif resp.lower() == "continue" or resp.lower() == "c":
print("Continuing download...")
elif resp.lower() in ["continue", "c"]:
console.print("[blue]Continuing download...[/blue]")
break
else:
cprint("Invalid response. Please try again.", "red")
console.print("[red]Invalid response. Please try again.[/red]")
for fname, url in entry.files.items():
output_file = str(output_dir / fname)
downloader = ResumableDownloader(url, output_file)
asyncio.run(downloader.download())
# Create download tasks for all files in the manifest
tasks = [
DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
for fname, url in entry.files.items()
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(
max_concurrent_downloads=max_concurrent_downloads
)
asyncio.run(downloader.download_all(tasks))
class ResumableDownloader:
def __init__(
self,
url: str,
output_file: str,
total_size: int = 0,
buffer_size: int = 32 * 1024,
):
self.url = url
self.output_file = output_file
self.buffer_size = buffer_size
self.total_size = total_size
self.downloaded_size = 0
self.start_size = 0
self.start_time = 0
async def get_file_info(self, client: httpx.AsyncClient) -> None:
if self.total_size > 0:
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Main download command handler"""
try:
if args.manifest_file:
_download_from_manifest(args.manifest_file, args.max_parallel)
return
# Force disable compression when trying to retrieve file size
response = await client.head(
self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
)
response.raise_for_status()
self.url = str(response.url) # Update URL in case of redirects
self.total_size = int(response.headers.get("Content-Length", 0))
if self.total_size == 0:
raise ValueError(
"Unable to determine file size. The server might not support range requests."
)
if args.model_id is None:
parser.error("Please provide a model id")
return
async def download(self) -> None:
self.start_time = time.time()
async with httpx.AsyncClient(follow_redirects=True) as client:
await self.get_file_info(client)
# Handle comma-separated model IDs
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
if os.path.exists(self.output_file):
self.downloaded_size = os.path.getsize(self.output_file)
self.start_size = self.downloaded_size
if self.downloaded_size >= self.total_size:
print(f"Already downloaded `{self.output_file}`, skipping...")
return
from llama_models.sku_list import llama_meta_net_info, resolve_model
additional_size = self.total_size - self.downloaded_size
if not self.has_disk_space(additional_size):
M = 1024 * 1024 # noqa
print(
f"Not enough disk space to download `{self.output_file}`. "
f"Required: {(additional_size // M):.2f} MB"
)
raise ValueError(
f"Not enough disk space to download `{self.output_file}`"
)
while True:
if self.downloaded_size >= self.total_size:
break
# Cloudfront has a max-size limit
max_chunk_size = 27_000_000_000
request_size = min(
self.total_size - self.downloaded_size, max_chunk_size
)
headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
}
print(f"Downloading `{self.output_file}`....{headers}")
try:
async with client.stream(
"GET", self.url, headers=headers
) as response:
response.raise_for_status()
with open(self.output_file, "ab") as file:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
self.downloaded_size += len(chunk)
self.print_progress()
except httpx.HTTPError as e:
print(f"\nDownload interrupted: {e}")
print("You can resume the download by running the script again.")
except Exception as e:
print(f"\nAn error occurred: {e}")
print(f"\nFinished downloading `{self.output_file}`....")
def print_progress(self) -> None:
percent = (self.downloaded_size / self.total_size) * 100
bar_length = 50
filled_length = int(bar_length * self.downloaded_size // self.total_size)
bar = "" * filled_length + "-" * (bar_length - filled_length)
elapsed_time = time.time() - self.start_time
M = 1024 * 1024 # noqa
speed = (
(self.downloaded_size - self.start_size) / (elapsed_time * M)
if elapsed_time > 0
else 0
)
print(
f"\rProgress: |{bar}| {percent:.2f}% "
f"({self.downloaded_size // M}/{self.total_size // M} MB) "
f"Speed: {speed:.2f} MiB/s",
end="",
flush=True,
from .model.safety_models import (
prompt_guard_download_info,
prompt_guard_model_sku,
)
def has_disk_space(self, file_size: int) -> bool:
dir_path = os.path.dirname(os.path.abspath(self.output_file))
free_space = shutil.disk_usage(dir_path).free
return free_space > file_size
prompt_guard = prompt_guard_model_sku()
for model_id in model_ids:
if model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email "
f"after visiting https://www.llama.com/llama-downloads/ "
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")

View file

@ -9,6 +9,7 @@ import argparse
from .download import Download
from .model import ModelParser
from .stack import StackParser
from .verify_download import VerifyDownload
class LlamaCLIParser:
@ -27,9 +28,10 @@ class LlamaCLIParser:
subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands
Download.create(subparsers)
ModelParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()

View file

@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand
@ -32,3 +33,4 @@ class ModelParser(Subcommand):
ModelList.create(subparsers)
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.verify_download import setup_verify_download_parser
setup_verify_download_parser(self.parser)

View file

@ -9,12 +9,17 @@ import argparse
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
import os
import shutil
from functools import lru_cache
from pathlib import Path
TEMPLATES_PATH = (
Path(os.path.relpath(__file__)).parent.parent.parent / "distribution" / "templates"
)
import pkg_resources
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@lru_cache()
@ -22,11 +27,10 @@ def available_templates_specs() -> List[BuildConfig]:
import yaml
template_specs = []
for p in TEMPLATES_PATH.rglob("*.yaml"):
for p in TEMPLATES_PATH.rglob("*build.yaml"):
with open(p, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs.append(build_config)
return template_specs
@ -65,174 +69,57 @@ class StackBuild(Subcommand):
help="Show the available templates for building a Llama Stack distribution",
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the Llama Stack build to override from template config. This name will be used as paths to store configuration files, build conda environments/docker images. If not specified, will use the name from the template config. ",
)
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type to use for the build. This can be either conda or docker. If not specified, will use the image type from the template config.",
choices=["conda", "docker"],
)
def _get_build_config_from_name(self, args: argparse.Namespace) -> Optional[Path]:
if os.getenv("CONDA_PREFIX", ""):
conda_dir = (
Path(os.getenv("CONDA_PREFIX")).parent / f"llamastack-{args.name}"
)
else:
cprint(
"Cannot find CONDA_PREFIX. Trying default conda path ~/.conda/envs...",
color="green",
)
conda_dir = (
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.name}"
)
build_config_file = Path(conda_dir) / f"{args.name}-build.yaml"
if build_config_file.exists():
return build_config_file
return None
def _run_stack_build_command_from_build_config(
self, build_config: BuildConfig
) -> None:
import json
import os
import yaml
from llama_stack.distribution.build import ApiInput, build_image, ImageType
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import cprint
# save build.yaml spec for building same distribution again
if build_config.image_type == ImageType.docker.value:
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
llama_stack_path = Path(
os.path.abspath(__file__)
).parent.parent.parent.parent
build_dir = llama_stack_path / "tmp/configs/"
else:
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
os.makedirs(build_dir, exist_ok=True)
build_file_path = build_dir / f"{build_config.name}-build.yaml"
with open(build_file_path, "w") as f:
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
return_code = build_image(build_config, build_file_path)
if return_code != 0:
return
configure_name = (
build_config.name
if build_config.image_type == "conda"
else (f"llamastack-{build_config.name}")
)
if build_config.image_type == "conda":
cprint(
f"You can now run `llama stack configure {configure_name}`",
color="green",
)
else:
cprint(
f"You can now run `llama stack run {build_config.name}`",
color="green",
)
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
import yaml
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Template Name",
"Providers",
"Description",
]
rows = []
for spec in available_templates_specs():
rows.append(
[
spec.name,
json.dumps(spec.distribution_spec.providers, indent=2),
spec.distribution_spec.description,
]
)
print_table(
rows,
headers,
separate_rows=True,
default="conda",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.distribution.distribution import get_provider_registry
if args.list_templates:
self._run_template_list_cmd(args)
return
if args.template:
if not args.name:
self.parser.error(
"You must specify a name for the build using --name when using a template"
)
return
build_path = TEMPLATES_PATH / f"{args.template}-build.yaml"
if not build_path.exists():
self.parser.error(
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates"
)
return
with open(build_path, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
build_config.name = args.name
if args.image_type:
build_config.image_type = args.image_type
self._run_stack_build_command_from_build_config(build_config)
return
# try to see if we can find a pre-existing build config file through name
if args.name:
maybe_build_config = self._get_build_config_from_name(args)
if maybe_build_config:
cprint(
f"Building from existing build config for {args.name} in {str(maybe_build_config)}...",
"green",
)
with open(maybe_build_config, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
self._run_stack_build_command_from_build_config(build_config)
available_templates = available_templates_specs()
for build_config in available_templates:
if build_config.name == args.template:
if args.image_type:
build_config.image_type = args.image_type
else:
self.parser.error(
f"Please specify a image-type (docker | conda) for {args.template}"
)
self._run_stack_build_command_from_build_config(
build_config, template_name=args.template
)
return
self.parser.error(
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates"
)
return
if not args.config and not args.template:
if not args.name:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
validator=Validator.from_callable(
lambda x: len(x) > 0,
error_message="Name cannot be empty, please enter a name",
),
)
else:
name = args.name
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
validator=Validator.from_callable(
lambda x: len(x) > 0,
error_message="Name cannot be empty, please enter a name",
),
)
image_type = prompt(
"> Enter the image type you want your Llama Stack to be built as (docker or conda): ",
@ -244,26 +131,31 @@ class StackBuild(Subcommand):
)
cprint(
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green",
)
print("Tip: use <TAB> to see options for the providers.\n")
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
api.value
),
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in providers_for_api,
error_message="Invalid provider, please enter one of the following: {}".format(
list(providers_for_api.keys())
),
),
default=(
"meta-reference"
if "meta-reference" in providers_for_api
else list(providers_for_api.keys())[0]
lambda x: x in available_providers,
error_message="Invalid provider, use <TAB> to see options",
),
)
@ -292,3 +184,153 @@ class StackBuild(Subcommand):
self.parser.error(f"Could not parse config file {args.config}: {e}")
return
self._run_stack_build_command_from_build_config(build_config)
def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
import json
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
docker_image=(
build_config.name
if build_config.image_type == ImageType.docker.value
else None
),
image_name=build_config.name,
conda_env=(
build_config.name
if build_config.image_type == ImageType.conda.value
else None
),
apis=apis,
providers={},
)
# build providers dict
provider_registry = get_provider_registry()
for api in apis:
run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api]
if isinstance(provider_types, str):
provider_types = [provider_types]
for i, provider_type in enumerate(provider_types):
pid = provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{build_config.name}"
)
else:
config = {}
p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
config=config,
)
run_config.providers[api].append(p_spec)
os.makedirs(build_dir, exist_ok=True)
run_config_file = build_dir / f"{build_config.name}-run.yaml"
with open(run_config_file, "w") as f:
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
cprint(
f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`",
color="green",
)
def _run_stack_build_command_from_build_config(
self, build_config: BuildConfig, template_name: Optional[str] = None
) -> None:
import json
import os
import re
import yaml
from termcolor import cprint
from llama_stack.distribution.build import build_image
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
# save build.yaml spec for building same distribution again
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
os.makedirs(build_dir, exist_ok=True)
build_file_path = build_dir / f"{build_config.name}-build.yaml"
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
return_code = build_image(build_config, build_file_path)
if return_code != 0:
return
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template_name}/run.yaml"
)
os.makedirs(build_dir, exist_ok=True)
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(template_path, run_config_file)
with open(template_path, "r") as f:
yaml_content = f.read()
# Find all ${env.VARIABLE} patterns
env_vars = set(re.findall(r"\${env\.([A-Za-z0-9_]+)}", yaml_content))
cprint("Build Successful! Next steps: ", color="green")
cprint(
f" 1. Set the environment variables: {list(env_vars)}",
color="green",
)
cprint(
f" 2. Run: `llama stack run {template_name}`",
color="green",
)
else:
self._generate_run_config(build_config, build_dir)
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Template Name",
"Providers",
"Description",
]
rows = []
for spec in available_templates_specs():
rows.append(
[
spec.name,
json.dumps(spec.distribution_spec.providers, indent=2),
spec.distribution_spec.description,
]
)
print_table(
rows,
headers,
separate_rows=True,
)

View file

@ -7,8 +7,6 @@
import argparse
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.datatypes import * # noqa: F403
class StackConfigure(Subcommand):
@ -39,138 +37,10 @@ class StackConfigure(Subcommand):
)
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
import json
import os
import subprocess
from pathlib import Path
import pkg_resources
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.utils.exec import run_with_pty
docker_image = None
build_config_file = Path(args.config)
if build_config_file.exists():
with open(build_config_file, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
self._configure_llama_distribution(build_config, args.output_dir)
return
# if we get here, we need to try to find the conda build config file
cprint(
f"Could not find {build_config_file}. Trying conda build name instead...",
color="green",
)
conda_dir = (
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
)
output = subprocess.check_output(
["bash", "-c", "conda info --json -a"]
)
conda_envs = json.loads(output.decode("utf-8"))["envs"]
for x in conda_envs:
if x.endswith(f"/llamastack-{args.config}"):
conda_dir = Path(x)
break
build_config_file = Path(conda_dir) / f"{args.config}-build.yaml"
if build_config_file.exists():
with open(build_config_file, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
self._configure_llama_distribution(build_config, args.output_dir)
return
# if we get here, we need to try to find the docker image
cprint(
f"Could not find {build_config_file}. Trying docker image name instead...",
color="green",
)
docker_image = args.config
builds_dir = BUILDS_BASE_DIR / ImageType.docker.value
if args.output_dir:
builds_dir = Path(output_dir)
os.makedirs(builds_dir, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_stack", "distribution/configure_container.sh"
)
script_args = [script, docker_image, str(builds_dir)]
return_code = run_with_pty(script_args)
# we have regenerated the build config file with script, now check if it exists
if return_code != 0:
self.parser.error(
f"Failed to configure container {docker_image} with return code {return_code}. Please run `llama stack build` first. "
)
return
return
def _configure_llama_distribution(
self,
build_config: BuildConfig,
output_dir: Optional[str] = None,
):
import json
import os
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.distribution.configure import configure_api_providers
from llama_stack.distribution.utils.serialize import EnumEncoder
builds_dir = BUILDS_BASE_DIR / build_config.image_type
if output_dir:
builds_dir = Path(output_dir)
os.makedirs(builds_dir, exist_ok=True)
image_name = build_config.name.replace("::", "-")
run_config_file = builds_dir / f"{image_name}-run.yaml"
if run_config_file.exists():
cprint(
f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
"yellow",
attrs=["bold"],
)
config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
else:
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
apis_to_serve=[],
api_providers={},
)
config = configure_api_providers(config, build_config.distribution_spec)
config.docker_image = (
image_name if build_config.image_type == "docker" else None
)
config.conda_env = image_name if build_config.image_type == "conda" else None
with open(run_config_file, "w") as f:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
cprint(
f"> YAML configuration has been written to `{run_config_file}`.",
color="blue",
)
cprint(
f"You can now run `llama stack run {image_name} --port PORT`",
color="green",
self.parser.error(
"""
DEPRECATED! llama stack configure has been deprecated.
Please use llama stack run <path/to/run.yaml> instead.
Please see example run.yaml in /distributions folder.
"""
)

View file

@ -5,9 +5,11 @@
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
REPO_ROOT = Path(__file__).parent.parent.parent.parent
class StackRun(Subcommand):
@ -40,16 +42,24 @@ class StackRun(Subcommand):
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from pathlib import Path
import pkg_resources
import yaml
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import (
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import run_with_pty
if not args.config:
@ -57,26 +67,43 @@ class StackRun(Subcommand):
return
config_file = Path(args.config)
if not config_file.exists() and not args.config.endswith(".yaml"):
has_yaml_suffix = args.config.endswith(".yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
)
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
)
if not config_file.exists() and not args.config.endswith(".yaml"):
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to docker dir
config_file = Path(
BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml"
)
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(
DISTRIBS_BASE_DIR
/ f"llamastack-{args.config}"
/ f"{args.config}-run.yaml"
)
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist. Please run `llama stack build` and `llama stack configure <name>` to generate a run.yaml file"
f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
return
with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f))
print(f"Using config file: {config_file}")
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.docker_image:
script = pkg_resources.resource_filename(
@ -98,4 +125,16 @@ class StackRun(Subcommand):
if args.disable_ipv6:
run_args.append("--disable-ipv6")
for env_var in args.env:
if "=" not in env_var:
self.parser.error(
f"Environment variable '{env_var}' must be in KEY=VALUE format"
)
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
run_with_pty(run_args)

View file

@ -1,105 +0,0 @@
from argparse import Namespace
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.cli.stack.build import StackBuild
# temporary while we make the tests work
pytest.skip(allow_module_level=True)
@pytest.fixture
def stack_build():
parser = MagicMock()
subparsers = MagicMock()
return StackBuild(subparsers)
def test_stack_build_initialization(stack_build):
assert stack_build.parser is not None
assert stack_build.parser.set_defaults.called_once_with(
func=stack_build._run_stack_build_command
)
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_config(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config="test_config.yaml",
template=None,
list_templates=False,
name=None,
image_type="conda",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.cli.table.print_table")
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
args = Namespace(list_templates=True)
stack_build._run_stack_build_command(args)
mock_print_table.assert_called_once()
@patch("prompt_toolkit.prompt")
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_interactive(
mock_build_image, mock_build_config, mock_prompt, stack_build
):
args = Namespace(
config=None, template=None, list_templates=False, name=None, image_type=None
)
mock_prompt.side_effect = [
"test_name",
"conda",
"meta-reference",
"test description",
]
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
assert mock_prompt.call_count == 4
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_template(
mock_build_image, mock_build_config, stack_build
):
args = Namespace(
config=None,
template="test_template",
list_templates=False,
name="test_name",
image_type="docker",
)
with patch("builtins.open", MagicMock()):
with patch("yaml.safe_load") as mock_yaml_load:
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
mock_build_config.return_value = MagicMock()
stack_build._run_stack_build_command(args)
mock_build_config.assert_called_once()
mock_build_image.assert_called_once()

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
"""
version: {version}
image_name: foo
apis_to_serve: []
built_at: {built_at}
providers:
inference:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
""".format(
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def old_config():
return yaml.safe_load(
"""
image_name: foo
built_at: {built_at}
apis_to_serve: []
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 11434
routing_key: Llama3.2-1B-Instruct
- provider_type: inline::meta-reference
config:
model: Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct
safety:
- routing_key: ["shield1", "shield2"]
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- routing_key: vector
provider_type: inline::meta-reference
config: {{}}
api_providers:
telemetry:
provider_type: noop
config: {{}}
""".format(
built_at=datetime.now().isoformat()
)
)
@pytest.fixture
def invalid_config():
return yaml.safe_load(
"""
routing_table: {}
api_providers: {}
"""
)
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
result = parse_and_maybe_upgrade_config(up_to_date_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert "inference" in result.providers
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert all(
api in result.providers
for api in ["inference", "safety", "memory", "telemetry"]
)
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "meta-reference"
assert "llama_guard_shield" in safety_provider.config
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
"remote::ollama-00",
"meta-reference-01",
}
ollama = inference_providers[0]
assert ollama.provider_type == "remote::ollama"
assert ollama.config["port"] == 11434
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(ValueError):
parse_and_maybe_upgrade_config(invalid_config)

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.cli.subcommand import Subcommand
@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: Optional[str]
exists: bool
matches: bool
class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
md5_hash = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()
def load_checksums(checklist_path: Path) -> Dict[str, str]:
checksums = {}
with open(checklist_path, "r") as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = md5sum
return checksums
def verify_files(
model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
results = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
matches = False
if exists:
actual_hash = calculate_md5(full_path)
matches = actual_hash == expected_hash
results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)
progress.remove_task(task_id)
return results
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.distribution.utils.model_utils import model_local_dir
console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"
if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")
if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")
checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)
# Print results
console.print("\nVerification Results:")
all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")
if all_good:
console.print("\n[green]All files verified successfully![/green]")

View file

@ -4,26 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from enum import Enum
from typing import List, Optional
from typing import List
import pkg_resources
from llama_stack.distribution.utils.exec import run_with_pty
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"aiosqlite",
"fastapi",
"fire",
"httpx",
@ -36,28 +39,19 @@ class ImageType(Enum):
conda = "conda"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
class ApiInput(BaseModel):
api: Api
provider: str
def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)
# extend package dependencies based on providers spec
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]]
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
deps = []
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
@ -67,25 +61,50 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
)
for provider in providers:
if provider not in providers_for_api:
# Providers from BuildConfig and RunConfig are subtly different  not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
if provider_type not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
normal_deps = []
special_deps = []
deps = []
for package in package_deps.pip_packages:
for package in deps:
if "--no-deps" in package or "--index-url" in package:
special_deps.append(package)
else:
deps.append(package)
deps = list(set(deps))
special_deps = list(set(special_deps))
normal_deps.append(package)
return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
log.info(f"\tpip install {special_dep}")
print()
def build_image(build_config: BuildConfig, build_file_path: Path):
docker_image = build_config.distribution_spec.docker_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
@ -94,10 +113,10 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
args = [
script,
build_config.name,
package_deps.docker_image,
docker_image,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps),
" ".join(normal_deps),
]
else:
script = pkg_resources.resource_filename(
@ -107,7 +126,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
script,
build_config.name,
str(build_file_path),
" ".join(deps),
" ".join(normal_deps),
]
if special_deps:
@ -115,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
return_code = run_with_pty(args)
if return_code != 0:
cprint(
log.error(
f"Failed to build target {build_config.name} with return code {return_code}",
color="red",
)
return return_code

View file

@ -1,8 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
BUILD_PLATFORM=${BUILD_PLATFORM:-}
if [ "$#" -lt 4 ]; then
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
@ -15,7 +22,7 @@ special_pip_deps="$6"
set -euo pipefail
build_name="$1"
image_name="llamastack-$build_name"
image_name="distribution-$build_name"
docker_base=$2
build_file_path=$3
host_build_dir=$4
@ -30,13 +37,9 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
TEMP_DIR=$(mktemp -d)
llama stack configure $build_file_path
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
add_to_docker() {
local input
output_file="$TEMP_DIR/Dockerfile"
@ -62,6 +65,19 @@ RUN apt-get update && apt-get install -y \
EOF
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install --no-cache $pip_dependencies"
fi
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
add_to_docker "RUN pip install --no-cache $part"
done
fi
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
@ -74,9 +90,18 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
# Install in editable format. We will mount the source code into the container
# so that changes will be reflected in the container without having to do a
# rebuild. This is just for development convenience.
add_to_docker "RUN pip install -e $stack_mount"
add_to_docker "RUN pip install --no-cache -e $stack_mount"
else
add_to_docker "RUN pip install llama-stack"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
add_to_docker "RUN pip install fastapi libcst"
add_to_docker <<EOF
RUN pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
EOF
else
add_to_docker "RUN pip install --no-cache llama-stack"
fi
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
@ -87,34 +112,20 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
add_to_docker <<EOF
RUN pip uninstall -y llama-models
RUN pip install $models_mount
RUN pip install --no-cache $models_mount
EOF
fi
if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<< "$special_pip_deps"
for part in "${parts[@]}"; do
add_to_docker "RUN pip install $part"
done
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$build_name"]
EOF
add_to_docker "ADD tmp/configs/$(basename "$build_file_path") ./llamastack-build.yaml"
add_to_docker "ADD tmp/configs/$build_name-run.yaml ./llamastack-run.yaml"
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
cat $TEMP_DIR/Dockerfile
printf "\n"
@ -127,16 +138,41 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
fi
# Set version tag based on PyPI version
if [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
version_tag="dev"
else
URL="https://pypi.org/pypi/llama-stack/json"
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
# Add version tag to image name
image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
PLATFORM="--platform $BUILD_PLATFORM"
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
PLATFORM="--platform linux/arm64"
elif [ "$ARCH" = "x86_64" ]; then
PLATFORM="--platform linux/amd64"
else
echo "Unsupported architecture: $ARCH"
exit 1
fi
set -x
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
$DOCKER_BINARY build $DOCKER_OPTS $PLATFORM -t $image_tag -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
# clean up tmp/configs
rm -rf $REPO_CONFIGS_DIR
set +x
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
echo "Success!"

View file

@ -0,0 +1,226 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, get_args, get_origin, Type, Union
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol) -> Type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
self.base_url = base_url.rstrip("/")
self.routes = {}
# Store routes for this protocol
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
async def shutdown(self):
pass
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
# TODO: make this more precise, same thing needs to happen in server.py
is_streaming = kwargs.get("stream", False)
if is_streaming:
return self._call_streaming(method_name, *args, **kwargs)
else:
return await self._call_non_streaming(method_name, *args, **kwargs)
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
_, sig = self.routes[method_name]
if sig.return_annotation is None:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert (
return_type
), f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
response = await client.request(**params)
response.raise_for_status()
j = response.json()
if j is None:
return None
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
return parse_obj_as(return_type, j)
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert (
return_type
), f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
async with client.stream(**params) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line[len("data: ") :]
try:
data = json.loads(data)
if "error" in data:
cprint(data, "red")
continue
yield parse_obj_as(return_type, data)
except Exception as e:
print(f"Error with parsing or validation: {e}")
print(data)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]
parameters = list(sig.parameters.values())[1:] # skip `self`
for i, param in enumerate(parameters):
if i >= len(args):
break
kwargs[param.name] = args[i]
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
def convert(value):
if isinstance(value, list):
return [convert(v) for v in value]
elif isinstance(value, dict):
return {k: convert(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
else:
return value
params = {}
data = {}
if webmethod.method == "GET":
params.update(kwargs)
else:
data.update(convert(kwargs))
ret = dict(
method=webmethod.method or "POST",
url=url,
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
timeout=30,
)
if params:
ret["params"] = params
if data:
ret["json"] = data
return ret
# Add protocol methods to the wrapper
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"
_CLIENT_CLASSES[protocol] = APIClient
return APIClient
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return type_hint
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None
async def example(model: str = None):
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
from llama_stack.apis.inference.event_logger import EventLogger
client_class = create_api_client_class(Inference)
client = client_class("http://localhost:5003")
if not model:
model = "Llama3.2-3B-Instruct"
message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon"
)
cprint(f"User>{message.content}", "green")
stream = True
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
if __name__ == "__main__":
import asyncio
asyncio.run(example())

View file

@ -3,189 +3,190 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
from llama_models.sku_list import (
llama3_1_family,
llama3_2_family,
llama3_family,
resolve_model,
safety_models,
)
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.impls.meta_reference.safety.config import (
MetaReferenceShieldType,
)
ALLOWED_MODELS = (
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
)
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
logger = logging.getLogger(__name__)
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
if provider.config:
existing = config_type(**provider.config)
else:
existing = None
except Exception:
existing = None
return BaseModelWithConfig
cfg = prompt_for_config(config_type, existing)
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
)
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
"""Get corresponding builtin APIs given provider backed APIs"""
res = []
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in provider_backed_apis:
res.append(inf.routing_table_api.value)
return res
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
apis = config.apis_to_serve or list(spec.providers.keys())
# append the bulitin routing APIs
apis += get_builtin_apis(apis)
is_nux = len(config.providers) == 0
router_api2builtin_api = {
inf.router_api.value: inf.routing_table_api.value
for inf in builtin_automatically_routed_apis()
}
if is_nux:
logger.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
"""
)
)
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
apis = [v.value for v in stack_apis()]
all_providers = get_provider_registry()
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():
if api_str not in apis:
for api_str in apis_to_serve:
api = Api(api_str)
if api in builtin_apis:
continue
if api not in provider_registry:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
api = Api(api_str)
p = spec.providers[api_str]
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
if isinstance(p, list):
cprint(
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
"yellow",
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
p = p[0]
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.api_providers.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
if api_str in router_api2builtin_api:
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
routing_entries = []
if api_str == "inference":
if hasattr(cfg, "model"):
routing_key = cfg.model
else:
routing_key = prompt(
"> Please enter the supported model your provider has for inference: ",
default="Llama3.1-8B-Instruct",
validator=Validator.from_callable(
lambda x: resolve_model(x) is not None,
error_message="Model must be: {}".format(
[x.descriptor() for x in ALLOWED_MODELS]
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
if p == "meta-reference":
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_type=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
"yellow",
attrs=["bold"],
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
if api_str == "memory":
bank_types = list([x.value for x in MemoryBankType])
routing_key = prompt(
"> Please enter the supported memory bank type your provider has for memory: ",
default="vector",
validator=Validator.from_callable(
lambda x: x in bank_types,
error_message="Invalid provider, please enter one of the following: {}".format(
bank_types
),
),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_type=p,
config=cfg.dict(),
)
)
config.routing_table[api_str] = routing_entries
config.api_providers[api_str] = PlaceholderProviderConfig(
providers=p if isinstance(p, list) else [p]
)
logger.info("")
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_type=p,
config=cfg.dict(),
)
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
plist = plist if isinstance(plist, list) else [plist]
print("")
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:
others = ", ".join(plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider_type})`")
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_type=provider_type,
config={},
),
)
)
logger.info("")
config.providers[api_str] = updated_providers
return config
def upgrade_from_routing_table(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_type=entry["provider_type"],
config=entry["config"],
)
for i, entry in enumerate(entries)
]
providers_by_api = {}
routing_table = config_dict.get("routing_table", {})
for api_str, entries in routing_table.items():
providers = get_providers(entries)
providers_by_api[api_str] = providers
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
if provider_map:
for api_str, provider in provider_map.items():
if isinstance(provider, dict) and "provider_type" in provider:
providers_by_api[api_str] = [
Provider(
provider_id=f"{provider['provider_type']}",
provider_type=provider["provider_type"],
config=provider["config"],
)
]
config_dict["providers"] = providers_by_api
config_dict.pop("routing_table", None)
config_dict.pop("api_providers", None)
config_dict.pop("provider_map", None)
config_dict["apis"] = config_dict["apis_to_serve"]
config_dict.pop("apis_to_serve", None)
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
return StackRunConfig(**config_dict)

View file

@ -4,35 +4,62 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
RoutableObject = Union[
Model,
Shield,
MemoryBank,
Dataset,
ScoringFn,
EvalTask,
]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
RoutableObjectWithProvider = Annotated[
Union[
Model,
Shield,
MemoryBank,
Dataset,
ScoringFn,
EvalTask,
],
Field(discriminator="type"),
]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutedProtocol = Union[
Inference,
Safety,
Memory,
DatasetIO,
Scoring,
Eval,
]
# Example: /inference, /safety
@ -53,18 +80,16 @@ class AutoRoutedProviderSpec(ProviderSpec):
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
router_api: Api
module: str
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
default="",
@ -80,10 +105,14 @@ in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class Provider(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
built_at: datetime
image_name: str = Field(
...,
@ -100,36 +129,34 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
apis_to_serve: List[str] = Field(
apis: List[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
api_providers: Dict[
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
] = Field(
providers: Dict[str, List[Provider]] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict,
metadata_store: Optional[KVStoreConfig] = Field(
default=None,
description="""
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Llama3.1-8B-Instruct
provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
""",
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list)
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
@json_schema_type
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
name: str

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
@ -35,6 +35,18 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.eval_tasks,
router_api=Api.eval,
),
]
@ -50,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
ret[api] = {a.provider_type: a for a in module.available_providers()}
return ret

View file

@ -6,45 +6,58 @@
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
class DistributionInspectConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = DistributionInspectImpl(config, deps)
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
run_config = self.config.run_config
ret = {}
all_providers = get_provider_registry()
for api, providers in all_providers.items():
ret[api.value] = [
for api, providers in run_config.providers.items():
ret[api] = [
ProviderInfo(
provider_id=p.provider_id,
provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
)
for p in providers.values()
for p in providers
]
return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
run_config = self.config.run_config
ret = {}
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, [])
ret[api.value] = [
RouteInfo(
route=e.route,
method=e.method,
providers=[],
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
import json
import logging
import threading
from typing import Any, Dict
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
@ -32,7 +35,7 @@ class NeedsRequestProviderData:
provider_data = validator(**val)
return provider_data
except Exception as e:
print("Error parsing provider data", e)
log.error("Error parsing provider data", e)
def set_request_provider_data(headers: Dict[str, str]):
@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
try:
val = json.loads(val)
except json.JSONDecodeError:
print("Provider data not encoded as a JSON object!", val)
log.error("Provider data not encoded as a JSON object!", val)
return
_THREAD_LOCAL.provider_data_header_value = val

View file

@ -4,159 +4,287 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.inspect import DistributionInspectImpl
import logging
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
class InvalidProviderError(Exception):
pass
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.memory: Memory,
Api.memory_banks: MemoryBanks,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
Api.datasetio: DatasetIO,
Api.datasets: Datasets,
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
}
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = get_provider_registry()
specs = {}
configs = {}
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_type not in providers:
raise ValueError(
f"Provider `{config.provider_type}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_table.keys())
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error, "red", attrs=["bold"])
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
api=info.routing_table_api,
router_api=info.router_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]),
),
)
}
routing_table = run_config.routing_table[info.router_api.value]
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_type not in providers:
raise ValueError(
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_type])
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
inner_specs=inner_specs,
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={
"run_config": run_config.dict(),
},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
),
),
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_type}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(
provider,
deps,
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:
api = Api(api_str)
impls[api] = impl
return impls
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
deps = []
for provider in providers:
for dep in provider.spec.deps__:
deps.append(dep)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
for dep in deps:
if dep not in visited:
dfs((dep, providers_with_specs[dep]), visited, stack)
stack.append(a.api)
stack.append(api_str)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
dfs((api_str, providers), visited, stack)
return [by_id[x] for x in stack]
flattened = []
for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))
return flattened
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
provider: ProviderWithSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
inner_impls: Dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -165,31 +293,95 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
args = [provider_spec.api, inner_impls, deps, dist_registry]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
config = config_type(**provider.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
missing_methods = []
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):
missing_methods.append((name, "not_callable"))
else:
# Check if the method signatures are compatible
obj_method = getattr(obj, name)
proto_sig = inspect.signature(value)
obj_sig = inspect.signature(obj_method)
proto_params = set(proto_sig.parameters)
proto_params.discard("self")
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
config,
{},
)
return impls

View file

@ -4,43 +4,62 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from .routing_tables import (
DatasetsRoutingTable,
EvalTasksRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
)
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](inner_impls, routing_table_config)
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
await impl.initialize()
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
from .routers import (
DatasetIORouter,
EvalRouter,
InferenceRouter,
MemoryRouter,
SafetyRouter,
ScoringRouter,
)
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -4,24 +4,27 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank type"""
"""Routes to an provider based on the memory bank identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
@ -29,32 +32,19 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memorybank_id: Optional[str] = None,
) -> None:
await self.routing_table.register_memory_bank(
memory_bank_id,
params,
provider_id,
provider_memorybank_id,
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
@ -62,7 +52,7 @@ class MemoryRouter(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.get_provider_from_bank_id(bank_id).insert_documents(
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds
)
@ -72,7 +62,7 @@ class MemoryRouter(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.get_provider_from_bank_id(bank_id).query_documents(
return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params
)
@ -92,11 +82,23 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
pass
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata
)
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@ -104,44 +106,52 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
params = dict(
model=model,
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
provider = self.routing_table.get_provider_impl(model_id)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return await provider.chat_completion(**params)
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion(
model=model,
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in await provider.completion(**params))
else:
return await provider.completion(**params)
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(model).embeddings(
model=model,
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
)
@ -159,14 +169,178 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
return await self.routing_table.register_shield(
shield_id, provider_shield_id, provider_id, params
)
async def run_shield(
self,
shield_type: str,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_rows_paginated(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(
dataset_id
).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
)
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_eval(
self,
task_id: str,
task_config: AppEvalTaskConfig,
) -> Job:
return await self.routing_table.get_provider_impl(task_id).run_eval(
task_id=task_id,
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
task_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(
task_id, job_id
)
async def job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
await self.routing_table.get_provider_impl(task_id).job_cancel(
task_id,
job_id,
)
async def job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).job_result(
task_id,
job_id,
)

View file

@ -4,141 +4,427 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.memory:
return await p.register_memory_bank(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[RoutingKey, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.unique_providers = []
self.providers = {}
self.routing_keys = []
for key, impl in inner_impls:
keys = key if isinstance(key, list) else [key]
self.unique_providers.append((keys, impl))
for k in keys:
if k in self.providers:
raise ValueError(f"Duplicate routing key {k}")
self.providers[k] = impl
self.routing_keys.append(k)
self.routing_table_config = routing_table_config
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.validate_routing_keys(keys)
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
async def shutdown(self) -> None:
for _, p in self.unique_providers:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.providers:
raise ValueError(f"Could not find provider for {routing_key}")
return self.providers[routing_key]
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, MemoryBanksRoutingTable):
return ("Memory", "memory_bank")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
else:
raise ValueError("Unknown routing table type")
def get_routing_keys(self) -> List[str]:
return self.routing_keys
apiname, objtype = apiname_object()
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
return await self.get_all_with_type("model")
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=entry,
)
)
return specs
async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", identifier)
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
provider_config=entry,
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
return None
if metadata is None:
metadata = {}
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type(ResourceType.shield.value)
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config:
if isinstance(entry.routing_key, list):
for k in entry.routing_key:
specs.append(
ShieldSpec(
shield_type=k,
provider_config=entry,
)
)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
return None
if params is None:
params = {}
shield = Shield(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBank]:
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
if provider_memory_bank_id is None:
provider_memory_bank_id = memory_bank_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
return None
memory_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
},
)
await self.register_object(memory_bank)
return memory_bank
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.unregister_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]:
return await self.get_all_with_type(ResourceType.dataset.value)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
if provider_dataset_id is None:
provider_dataset_id = dataset_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if metadata is None:
metadata = {}
dataset = Dataset(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
dataset_schema=dataset_schema,
url=url,
metadata=metadata,
)
await self.register_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFn(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[EvalTask]:
return await self.get_all_with_type(ResourceType.eval_task.value)
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", name)
async def register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
metadata: Optional[Dict[str, Any]] = None,
provider_eval_task_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_eval_task_id is None:
provider_eval_task_id = eval_task_id
eval_task = EvalTask(
identifier=eval_task_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_eval_task_id,
)
await self.register_object(eval_task)

View file

@ -9,15 +9,9 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
@ -31,18 +25,7 @@ class ApiEndpoint(BaseModel):
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
protocols = api_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
@ -52,7 +35,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
continue
webmethod = method.__webmethod__
route = webmethod.route
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"

View file

@ -4,62 +4,69 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import functools
import inspect
import json
import os
import signal
import sys
import traceback
import warnings
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from pathlib import Path
from typing import Any, Union
import fire
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.inline.meta_reference.telemetry.console import (
ConsoleConfig,
ConsoleTelemetryImpl,
)
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
REPO_ROOT = Path(__file__).parent.parent.parent.parent
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
warnings.showwarning = warn_with_traceback
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
data = data.model_dump_json()
else:
data = json.dumps(data)
@ -108,72 +115,20 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
asyncio.run(run_shutdown())
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
@ -182,76 +137,57 @@ async def lifespan(app: FastAPI):
print("Starting up")
yield
print("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
return endpoint
async def maybe_await(value):
if inspect.iscoroutine(value):
return await value
return value
async def sse_generator(event_gen):
try:
event_gen = await event_gen
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
try:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
new_params = [
@ -275,54 +211,118 @@ def create_dynamic_typed_route(func: Any, method: str):
return endpoint
def main(
yaml_config: str = "llamastack-run.yaml",
port: int = 5000,
disable_ipv6: bool = False,
):
with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp))
class TracingMiddleware:
def __init__(self, app):
self.app = app
app = FastAPI()
async def __call__(self, scope, receive, send):
path = scope["path"]
await start_trace(path, {"location": "server"})
try:
return await self.app(scope, receive, send)
finally:
await end_trace()
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
help="Path to YAML configuration file",
)
parser.add_argument(
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
args = parser.parse_args()
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
print(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
print(f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
elif args.template:
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
with open(config_file, "r") as fp:
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
sys.exit(1)
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
for inf in builtin_automatically_routed_apis():
# if we do not serve the corresponding router API, we should not serve the routing table API
if inf.router_api.value not in apis_to_serve:
continue
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
impl_method = getattr(impl, endpoint.name)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=UserWarning, module="pydantic._internal._fields"
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
@ -337,15 +337,18 @@ def main(
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
app.__llama_stack_impls__ = impls
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{args.port}")
uvicorn.run(app, host=listen_host, port=args.port)
if __name__ == "__main__":
fire.Fire(main)
main()

View file

@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from pathlib import Path
from typing import Any, Dict
import pkg_resources
import yaml
from termcolor import colored
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
method = getattr(impls[api], register_method)
for obj in objects:
await method(**obj.model_dump())
method = getattr(impls[api], list_method)
for obj in await method():
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, list):
result = []
for i, v in enumerate(config):
try:
result.append(replace_env_vars(v, f"{path}[{i}]"))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result
elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
def get_env_var(match):
env_var = match.group(1)
default_val = match.group(2)
value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
else:
value = default_val
# expand "~" from the values
return os.path.expanduser(value)
try:
return re.sub(pattern, get_env_var, config)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return config
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
key, value = env_pair.split("=", 1)
key = key.strip()
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
return key, value
except ValueError as e:
raise ValueError(
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
) from e
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
await register_resources(run_config, impls)
return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template}/run.yaml"
)
if not Path(template_path).exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
with open(template_path) as f:
run_config = yaml.safe_load(f)
return StackRunConfig(**replace_env_vars(run_config))

View file

@ -33,10 +33,33 @@ shift
port="$1"
shift
# Process environment variables from --env arguments
env_vars=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
# collect environment variables so we can set them after activating the conda env
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
shift
;;
esac
done
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
set -x
$CONDA_PREFIX/bin/python \
-m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars

View file

@ -10,6 +10,8 @@ DOCKER_BINARY=${DOCKER_BINARY:-docker}
DOCKER_OPTS=${DOCKER_OPTS:-}
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
set -euo pipefail
@ -29,7 +31,7 @@ if [ $# -lt 3 ]; then
fi
build_name="$1"
docker_image="llamastack-$build_name"
docker_image="localhost/distribution-$build_name"
shift
yaml_config="$1"
@ -38,6 +40,26 @@ shift
port="$1"
shift
# Process environment variables from --env arguments
env_vars=""
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
echo "env = $2"
if [[ -n "$2" ]]; then
env_vars="$env_vars -e $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
shift
;;
esac
done
set -x
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
@ -54,11 +76,21 @@ if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
DOCKER_OPTS="$DOCKER_OPTS --gpus=all"
fi
version_tag="latest"
if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION"
elif [ -n "$LLAMA_STACK_DIR" ]; then
version_tag="dev"
elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION"
fi
$DOCKER_BINARY run $DOCKER_OPTS -it \
-p $port:$port \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
$docker_image \
$docker_image:$version_tag \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"
--yaml-config /app/config.yaml \
--port "$port"

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .registry import * # noqa: F401 F403

View file

@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
async def initialize(self) -> None: ...
async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
async def update(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
async def delete(self, type: str, identifier: str) -> None: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(value),
)
all_objects.append(obj)
return all_objects
class DiskDistributionRegistry(DistributionRegistry):
def __init__(self, kvstore: KVStore):
self.kvstore = kvstore
async def initialize(self) -> None:
pass
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
async def get_all(self) -> List[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
if not json_str:
return None
objects_data = json.loads(json_str)
# Return only the first object if any exist
if objects_data:
return pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(objects_data),
)
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return obj
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
return False
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
obj.model_dump_json(),
)
return True
async def delete(self, type: str, identifier: str) -> None:
await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@asynccontextmanager
async def _locked_cache(self):
"""Context manager for safely accessing the cache with a lock."""
async with self._cache_lock:
yield self.cache
async def _ensure_initialized(self):
"""Ensures the registry is initialized before operations."""
if self._initialized:
return
async with self._initialize_lock:
if self._initialized:
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:
for obj in objects:
cache_key = (obj.type, obj.identifier)
cache[cache_key] = obj
self._initialized = True
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]:
await self._ensure_initialized()
async with self._locked_cache() as cache:
return list(cache.values())
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
await self._ensure_initialized()
cache_key = (type, identifier)
async with self._locked_cache() as cache:
return cache.get(cache_key, None)
async def register(self, obj: RoutableObjectWithProvider) -> bool:
await self._ensure_initialized()
success = await super().register(obj)
if success:
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return success
async def update(self, obj: RoutableObjectWithProvider) -> None:
await super().update(obj)
cache_key = (obj.type, obj.identifier)
async with self._locked_cache() as cache:
cache[cache_key] = obj
return obj
async def delete(self, type: str, identifier: str) -> None:
await super().delete(type, identifier)
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
del cache[cache_key]
async def create_dist_registry(
metadata_store: Optional[KVStoreConfig],
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()
return dist_registry, dist_kvstore

View file

@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.distribution.store import * # noqa F403
from llama_stack.apis.inference import Model
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import * # noqa F403
@pytest.fixture
def config():
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
if os.path.exists(config.db_path):
os.remove(config.db_path)
return config
@pytest_asyncio.fixture
async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest_asyncio.fixture
async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest.fixture
def sample_bank():
return VectorMemoryBank(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_resource_id="test_bank",
provider_id="test-provider",
)
@pytest.fixture
def sample_model():
return Model(
identifier="test_model",
provider_resource_id="test_model",
provider_id="test-provider",
)
@pytest.mark.asyncio
async def test_registry_initialization(registry):
# Test empty registry
results = await registry.get("nonexistent", "nonexistent")
assert len(results) == 0
@pytest.mark.asyncio
async def test_basic_registration(registry, sample_bank, sample_model):
print(f"Registering {sample_bank}")
await registry.register(sample_bank)
print(f"Registering {sample_model}")
await registry.register(sample_model)
print("Getting bank")
results = await registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
assert result_bank.embedding_model == sample_bank.embedding_model
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
assert result_bank.provider_id == sample_bank.provider_id
results = await registry.get("model", "test_model")
assert len(results) == 1
result_model = results[0]
assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_bank, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
await disk_registry.initialize()
await disk_registry.register(sample_bank)
await disk_registry.register(sample_model)
# Test cached version loads from disk
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
results = await cached_registry.get("memory_bank", "test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier
assert result_bank.embedding_model == sample_bank.embedding_model
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
assert result_bank.provider_id == sample_bank.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
new_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(new_bank)
# Verify in cache
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
assert result_bank.provider_id == new_bank.provider_id
# Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
await new_registry.initialize()
results = await new_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier
assert result_bank.provider_id == new_bank.provider_id
@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
original_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id="test_bank_2",
provider_id="baz",
)
await cached_registry.register(original_bank)
duplicate_bank = VectorMemoryBank(
identifier="test_bank_2",
embedding_model="different-model",
chunk_size_in_tokens=128,
overlap_size_in_tokens=16,
provider_resource_id="test_bank_2",
provider_id="baz", # Same provider_id
)
await cached_registry.register(duplicate_bank)
results = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 # Still only one result
assert (
results[0].embedding_model == original_bank.embedding_model
) # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(config):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
# Create multiple test banks
test_banks = [
VectorMemoryBank(
identifier=f"test_bank_{i}",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=256,
overlap_size_in_tokens=32,
provider_resource_id=f"test_bank_{i}",
provider_id=f"provider_{i}",
)
for i in range(3)
]
# Register all banks
for bank in test_banks:
await cached_registry.register(bank)
# Test get_all retrieval
all_results = await cached_registry.get_all()
assert len(all_results) == 3
# Verify each bank was stored correctly
for original_bank in test_banks:
matching_banks = [
b for b in all_results if b.identifier == original_bank.identifier
]
assert len(matching_banks) == 1
stored_bank = matching_banks[0]
assert stored_bank.embedding_model == original_bank.embedding_model
assert stored_bank.provider_id == original_bank.provider_id
assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
assert (
stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
)

View file

@ -1,15 +0,0 @@
name: local-cpu
distribution_spec:
description: remote inference + local safety/agents/memory
docker_image: null
providers:
inference:
- remote::ollama
- remote::tgi
- remote::together
- remote::fireworks
safety: meta-reference
agents: meta-reference
memory: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -1,49 +0,0 @@
built_at: '2024-09-30T09:04:30.533391'
image_name: local-cpu
docker_image: local-cpu
conda_env: null
apis_to_serve:
- agents
- inference
- models
- memory
- safety
- shields
- memory_banks
api_providers:
inference:
providers:
- remote::ollama
safety:
providers:
- meta-reference
agents:
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: remote::ollama
config:
host: localhost
port: 6000
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
config: {}
routing_key: vector

View file

@ -1,11 +0,0 @@
name: local-gpu
distribution_spec:
description: local meta reference
docker_image: null
providers:
inference: meta-reference
safety: meta-reference
agents: meta-reference
memory: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -1,52 +0,0 @@
built_at: '2024-09-30T09:00:56.693751'
image_name: local-gpu
docker_image: local-gpu
conda_env: null
apis_to_serve:
- memory
- inference
- agents
- shields
- safety
- models
- memory_banks
api_providers:
inference:
providers:
- meta-reference
safety:
providers:
- meta-reference
agents:
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
config: {}
routing_key: vector

View file

@ -1,10 +0,0 @@
name: local-bedrock-conda-example
distribution_spec:
description: Use Amazon Bedrock APIs.
providers:
inference: remote::bedrock
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-databricks
distribution_spec:
description: Use Databricks for running LLM inference
providers:
inference: remote::databricks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-fireworks
distribution_spec:
description: Use Fireworks.ai for running LLM inference
providers:
inference: remote::fireworks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-hf-endpoint
distribution_spec:
description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
providers:
inference: remote::hf::endpoint
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-hf-serverless
distribution_spec:
description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
providers:
inference: remote::hf::serverless
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-ollama
distribution_spec:
description: Like local, but use ollama for running LLM inference
providers:
inference: remote::ollama
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-tgi
distribution_spec:
description: Like local, but use a TGI server for running LLM inference.
providers:
inference: remote::tgi
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-together
distribution_spec:
description: Use Together.ai for running LLM inference
providers:
inference: remote::together
memory: meta-reference
safety: remote::together
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-vllm
distribution_spec:
description: Like local, but use vLLM for running LLM inference
providers:
inference: vllm
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import errno
import logging
import os
import pty
import select
@ -13,7 +14,7 @@ import subprocess
import sys
import termios
from termcolor import cprint
log = logging.getLogger(__name__)
# run a command in a pseudo-terminal, with interrupt handling,
@ -29,7 +30,7 @@ def run_with_pty(command):
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
log.info("\nCtrl-C detected. Aborting...")
try:
# Set up the signal handler
@ -100,6 +101,6 @@ def run_command(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
print(f"Error: {error.decode('utf-8')}")
log.error(f"Error: {error.decode('utf-8')}")
sys.exit(1)
return output.decode("utf-8")

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
from .config_dirs import DEFAULT_CHECKPOINT_DIR
def model_local_dir(descriptor: str) -> str:
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))

View file

@ -6,6 +6,7 @@
import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
print(f"Invalid {discriminator}. Please try again.")
log.error(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
@ -180,7 +183,7 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except KeyError:
print(
log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
continue
@ -197,7 +200,7 @@ def prompt_for_config(
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
@ -213,7 +216,7 @@ def prompt_for_config(
existing_value,
)
elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:")
log.info(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,
existing_value,
@ -240,7 +243,7 @@ def prompt_for_config(
config_data[field_name] = None
break
else:
print("This field is required. Please provide a value.")
log.error("This field is required. Please provide a value.")
continue
else:
try:
@ -264,12 +267,12 @@ def prompt_for_config(
value = [element_type(item) for item in value]
except json.JSONDecodeError:
print(
log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
)
continue
except ValueError as e:
print(f"{str(e)}")
log.error(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
@ -281,7 +284,7 @@ def prompt_for_config(
)
except json.JSONDecodeError:
print(
log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
@ -298,7 +301,7 @@ def prompt_for_config(
value = field_type(user_input)
except ValueError:
print(
log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
continue
@ -311,6 +314,6 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except ValueError as e:
print(f"Validation error: {str(e)}")
log.error(f"Validation error: {str(e)}")
return config_type(**config_data)

View file

@ -1,257 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> OpenAI:
return OpenAI(
base_url=self.config.url,
api_key=self.config.api_token
)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_databricks_messages(self, messages: list[Message]) -> list:
databricks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
databricks_messages.append({"role": role, "content": message.content})
return databricks_messages
def resolve_databricks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in DATABRICKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
return DATABRICKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model)
if not request.stream:
r = self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
for chunk in self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -1,247 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
}
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Fireworks:
return Fireworks(api_key=self.config.api_key)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
fireworks_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
fireworks_messages.append({"role": role, "content": message.content})
return fireworks_messages
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in self.client.chat.completions.acreate(
model=fireworks_model,
messages=self._messages_to_fireworks_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -1,266 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
}
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, url: str) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
)
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print("Initializing Ollama, checking connectivity to server...")
try:
await self.client.ps()
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
) from e
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
ollama_messages.append({"role": role, "content": message.content})
return ollama_messages
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
options["repeat_penalty"] = request.sampling_params.repetition_penalty
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
res = await self.client.ps()
need_model_pull = True
for r in res["models"]:
if ollama_model == r["model"]:
need_model_pull = False
break
if need_model_pull:
print(f"Pulling model: {ollama_model}")
status = await self.client.pull(ollama_model)
assert (
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
if not request.stream:
r = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=False,
options=options,
)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
stream = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(messages),
stream=True,
options=options,
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk["message"]["content"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -1,260 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
logger = logging.getLogger(__name__)
class _HfAdapter(Inference, RoutableProvider):
client: AsyncInferenceClient
max_tokens: int
model_id: str
def __init__(self) -> None:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
input_tokens = len(model_input.tokens)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
print(f"Calculated max_new_tokens: {max_new_tokens}")
options = self.get_chat_options(request)
if not request.stream:
response = await self.client.text_generation(
prompt=prompt,
stream=False,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
)
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
response.generated_text,
stop_reason,
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
tokens = []
async for response in await self.client.text_generation(
prompt=prompt,
stream=True,
details=True,
max_new_tokens=max_new_tokens,
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
):
token_result = response.token
buffer += token_result.text
tokens.append(token_result.id)
if not ipython and buffer.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# parse tool calls and report errors
message = self.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(
model=config.model_id, token=config.api_token
)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceEndpointAdapter(_HfAdapter):
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
# Get the inference endpoint details
api = HfApi(token=config.api_token)
endpoint = api.get_inference_endpoint(config.endpoint_name)
# Wait for the endpoint to be ready (if not already)
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
)

View file

@ -1,265 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}
class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels
):
def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> Together:
return Together(api_key=self.config.api_key)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_together_messages(self, messages: list[Message]) -> list:
together_messages = []
for message in messages:
if message.role == "ipython":
role = "tool"
else:
role = message.role
together_messages.append({"role": role, "content": message.content})
return together_messages
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(request.sampling_params, attr):
options[attr] = getattr(request.sampling_params, attr)
return options
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if (
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r.choices[0].message.content, stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
for chunk in client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=True,
**options,
):
if finish_reason := chunk.choices[0].finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].delta.content
if text is None:
continue
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -1,8 +0,0 @@
from .config import WeaviateConfig
async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class WeaviateRequestProviderData(BaseModel):
# if there _is_ provider data, it must specify the API KEY
# if you want it to be optional, use Optional[str]
weaviate_api_key: str
weaviate_cluster_url: str
@json_schema_type
class WeaviateConfig(BaseModel):
collection: str = Field(default="MemoryBank")

View file

@ -1,192 +0,0 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from numpy.typing import NDArray
import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth
from llama_stack.apis.memory import *
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection: str):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(wvc.data.DataObject(
properties={
"chunk_content": chunk,
},
vector = embeddings[i].tolist()
))
# Inserting chunks into a prespecified Weaviate collection
assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection)
await my_collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection)
results = my_collection.query.near_vector(
near_vector = embedding.tolist(),
limit = k,
return_meta_data = wvc.query.MetadataQuery(distance=True)
)
chunks = []
scores = []
for doc in results.objects:
try:
chunk = doc.properties["chunk_content"]
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
except Exception as e:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {e}")
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory):
def __init__(self, config: WeaviateConfig) -> None:
self.config = config
self.client = None
self.cache = {}
def _get_client(self) -> weaviate.Client:
request_provider_data = get_request_provider_data()
if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData)
# Connect to Weaviate Cloud
return weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)
async def initialize(self) -> None:
try:
self.client = self._get_client()
# Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
]
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None:
self.client = self._get_client()
if self.client:
self.client.close()
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
self.client = self._get_client()
# Store the bank as a new collection in Weaviate
self.client.collections.create(
name=bank_id
)
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(cleint = self.client, collection = bank_id),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self._get_client()
if bank_id in self.cache:
return self.cache[bank_id]
collections = await self.client.collections.list_all().keys()
for collection in collections:
if collection == bank_id:
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(self.client, collection),
)
self.cache[bank_id] = index
return index
return None
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -1,120 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
async def initialize(self) -> None:
try:
print(f"initializing with profile --- > {self.config}")
self.boto_client = boto3.Session(
profile_name=self.config.aws_profile
).client("bedrock-runtime")
except Exception as e:
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
)
return None

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: str = Field(
default="default",
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
)

View file

@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
@json_schema_type
class TogetherSafetyConfig(BaseModel):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key (default for the distribution, if any)",
)

Some files were not shown because too many files have changed in this diff Show more