Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-19 02:19:38 +00:00)

Merge branch 'main' into clarifai-inference-provider
This commit is contained in commit 4b9085d312.
536 changed files with 34661 additions and 12116 deletions.
@@ -6,7 +6,17 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)

 from llama_models.schema_utils import json_schema_type, webmethod
@@ -44,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
 class SearchEngineType(Enum):
     bing = "bing"
     brave = "brave"
+    tavily = "tavily"


 @json_schema_type
@@ -396,6 +407,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):

 @json_schema_type
 class AgentTurnResponseStreamChunk(BaseModel):
     """streamed agent turn completion response."""

     event: AgentTurnResponseEvent
@@ -404,6 +417,7 @@ class AgentStepResponse(BaseModel):
     step: Step


+@runtime_checkable
 class Agents(Protocol):
     @webmethod(route="/agents/create")
     async def create_agent(
@@ -424,18 +438,16 @@ class Agents(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgentTurnResponseStreamChunk: ...
+    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(route="/agents/turn/get")
     async def get_agents_turn(
-        self,
-        agent_id: str,
-        turn_id: str,
+        self, agent_id: str, session_id: str, turn_id: str
     ) -> Turn: ...

     @webmethod(route="/agents/step/get")
     async def get_agents_step(
-        self, agent_id: str, turn_id: str, step_id: str
+        self, agent_id: str, session_id: str, turn_id: str, step_id: str
     ) -> AgentStepResponse: ...

     @webmethod(route="/agents/session/create")
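Note: create_agent_turn now returns a completed Turn when stream=False and an async iterator of AgentTurnResponseStreamChunk when stream=True. A minimal caller-side sketch (illustrative only; `agents` is assumed to be any implementation of the Agents protocol, and the agent_id/session_id/messages parameters are assumed from the surrounding signature):

    # Sketch only: handle both shapes of the create_agent_turn return value.
    response = await agents.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="What is 2 + 2?")],
        stream=stream,  # caller-chosen bool
    )
    if isinstance(response, Turn):  # stream=False: a finished turn
        print(response.output_message)
    else:  # stream=True: AsyncIterator[AgentTurnResponseStreamChunk]
        async for chunk in response:
            print(chunk.event)
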
@ -7,22 +7,26 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from .agents import * # noqa: F403
|
||||
import logging
|
||||
|
||||
from .event_logger import EventLogger
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
|
@ -70,6 +74,14 @@ class AgentsClient(Agents):
|
|||
async def create_agent_turn(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
if request.stream:
|
||||
return self._stream_agent_turn(request)
|
||||
else:
|
||||
return await self._nonstream_agent_turn(request)
|
||||
|
||||
async def _stream_agent_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
|
|
@ -85,13 +97,15 @@ class AgentsClient(Agents):
|
|||
try:
|
||||
jdata = json.loads(data)
|
||||
if "error" in jdata:
|
||||
cprint(data, "red")
|
||||
log.error(data)
|
||||
continue
|
||||
|
||||
yield AgentTurnResponseStreamChunk(**jdata)
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
log.error(f"Error with parsing or validation: {e}")
|
||||
|
||||
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
|
||||
raise NotImplementedError("Non-streaming not implemented yet")
|
||||
|
||||
|
||||
async def _run_agent(
|
||||
|
|
@ -114,8 +128,8 @@ async def _run_agent(
|
|||
)
|
||||
|
||||
for content in user_prompts:
|
||||
cprint(f"User> {content}", color="white", attrs=["bold"])
|
||||
iterator = api.create_agent_turn(
|
||||
log.info(f"User> {content}", color="white", attrs=["bold"])
|
||||
iterator = await api.create_agent_turn(
|
||||
AgentTurnCreateRequest(
|
||||
agent_id=create_response.agent_id,
|
||||
session_id=session_response.session_id,
|
||||
|
|
@ -127,13 +141,12 @@ async def _run_agent(
|
|||
)
|
||||
)
|
||||
|
||||
async for event, log in EventLogger().log(iterator):
|
||||
if log is not None:
|
||||
log.print()
|
||||
async for event, logger in EventLogger().log(iterator):
|
||||
if logger is not None:
|
||||
log.info(logger)
|
||||
|
||||
|
||||
async def run_llama_3_1(host: str, port: int):
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
|
|
@ -173,8 +186,7 @@ async def run_llama_3_1(host: str, port: int):
|
|||
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
|
||||
|
||||
|
||||
async def run_llama_3_2_rag(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
|
|
@ -215,8 +227,7 @@ async def run_llama_3_2_rag(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
async def run_llama_3_2(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
# zero shot tools for llama3.2 text models
|
||||
|
|
@ -262,7 +273,7 @@ async def run_llama_3_2(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
def main(host: str, port: int, run_type: str):
|
||||
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
||||
assert run_type in [
|
||||
"tools_llama_3_1",
|
||||
"tools_llama_3_2",
|
||||
|
|
@ -274,7 +285,10 @@ def main(host: str, port: int, run_type: str):
|
|||
"tools_llama_3_2": run_llama_3_2,
|
||||
"rag_llama_3_2": run_llama_3_2_rag,
|
||||
}
|
||||
asyncio.run(fn[run_type](host, port))
|
||||
args = [host, port]
|
||||
if model is not None:
|
||||
args.append(model)
|
||||
asyncio.run(fn[run_type](*args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|

@@ -180,5 +180,5 @@ class EventLogger:
                 color="cyan",
             )

-        preivous_event_type = event_type
+        previous_event_type = event_type
         previous_step_type = step_type

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol
+from typing import List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod

@@ -47,8 +47,9 @@ class BatchChatCompletionResponse(BaseModel):
     completion_message_batch: List[CompletionMessage]


+@runtime_checkable
 class BatchInference(Protocol):
-    @webmethod(route="/batch_inference/completion")
+    @webmethod(route="/batch-inference/completion")
     async def batch_completion(
         self,
         model: str,

@@ -57,7 +58,7 @@ class BatchInference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...

-    @webmethod(route="/batch_inference/chat_completion")
+    @webmethod(route="/batch-inference/chat-completion")
     async def batch_chat_completion(
         self,
         model: str,

llama_stack/apis/common/job_types.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel


@json_schema_type
class Job(BaseModel):
    job_id: str


@json_schema_type
class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"

llama_stack/apis/common/type_system.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class StringType(BaseModel):
    type: Literal["string"] = "string"


class NumberType(BaseModel):
    type: Literal["number"] = "number"


class BooleanType(BaseModel):
    type: Literal["boolean"] = "boolean"


class ArrayType(BaseModel):
    type: Literal["array"] = "array"


class ObjectType(BaseModel):
    type: Literal["object"] = "object"


class JsonType(BaseModel):
    type: Literal["json"] = "json"


class UnionType(BaseModel):
    type: Literal["union"] = "union"


class ChatCompletionInputType(BaseModel):
    # expects List[Message] for messages
    type: Literal["chat_completion_input"] = "chat_completion_input"


class CompletionInputType(BaseModel):
    # expects InterleavedTextMedia for content
    type: Literal["completion_input"] = "completion_input"


class AgentTurnInputType(BaseModel):
    # expects List[Message] for messages (may also include attachments?)
    type: Literal["agent_turn_input"] = "agent_turn_input"


ParamType = Annotated[
    Union[
        StringType,
        NumberType,
        BooleanType,
        ArrayType,
        ObjectType,
        JsonType,
        UnionType,
        ChatCompletionInputType,
        CompletionInputType,
        AgentTurnInputType,
    ],
    Field(discriminator="type"),
]

# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()


# class CustomType(BaseModel):
#     type: Literal["custom"] = "custom"
#     validator_class: str
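Note: these ParamType variants form a tagged union discriminated on the literal `type` field, so a dataset schema is just a mapping from column name to one of these instances. A small illustrative sketch (column names taken from the dataset registration example later in this diff):

    # Illustrative only: a dataset schema is a plain dict of column name -> ParamType instance.
    schema = {
        "input_query": StringType(),
        "expected_answer": StringType(),
        "generated_answer": StringType(),
    }
    for name, col_type in schema.items():
        # the discriminator field tells consumers how to interpret the column
        print(name, col_type.type)  # e.g. "input_query string"
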
|
@ -1,63 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainEvalDatasetColumnType(Enum):
|
||||
dialog = "dialog"
|
||||
text = "text"
|
||||
media = "media"
|
||||
number = "number"
|
||||
json = "json"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainEvalDataset(BaseModel):
|
||||
"""Dataset to be used for training or evaluating language models."""
|
||||
|
||||
# TODO(ashwin): figure out if we need to add an enum for a "dataset type"
|
||||
|
||||
columns: Dict[str, TrainEvalDatasetColumnType]
|
||||
content_url: URL
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CreateDatasetRequest(BaseModel):
|
||||
"""Request to create a dataset."""
|
||||
|
||||
uuid: str
|
||||
dataset: TrainEvalDataset
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets/create")
|
||||
def create_dataset(
|
||||
self,
|
||||
uuid: str,
|
||||
dataset: TrainEvalDataset,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/datasets/get")
|
||||
def get_dataset(
|
||||
self,
|
||||
dataset_uuid: str,
|
||||
) -> TrainEvalDataset: ...
|
||||
|
||||
@webmethod(route="/datasets/delete")
|
||||
def delete_dataset(
|
||||
self,
|
||||
dataset_uuid: str,
|
||||
) -> None: ...
|
||||
|

@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .reward_scoring import *  # noqa: F401 F403
+from .datasetio import *  # noqa: F401 F403

llama_stack/apis/datasetio/client.py (new file, 103 lines)
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasets.client import DatasetsClient
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
|
||||
|
||||
|
||||
class DatasetIOClient(DatasetIO):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/datasetio/get_rows_paginated",
|
||||
params={
|
||||
"dataset_id": dataset_id,
|
||||
"rows_in_page": rows_in_page,
|
||||
"page_token": page_token,
|
||||
"filter_condition": filter_condition,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.json():
|
||||
return
|
||||
|
||||
return PaginatedRowsResult(**response.json())
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = DatasetsClient(f"http://{host}:{port}")
|
||||
|
||||
# register dataset
|
||||
test_file = (
|
||||
Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
/ "providers/tests/datasetio/test_dataset.csv"
|
||||
)
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
response = await client.register_dataset(
|
||||
DatasetDefWithProvider(
|
||||
identifier="test-dataset",
|
||||
provider_id="meta0",
|
||||
url=URL(
|
||||
uri=test_url,
|
||||
),
|
||||
dataset_schema={
|
||||
"generated_answer": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# list datasets
|
||||
list_dataset = await client.list_datasets()
|
||||
cprint(list_dataset, "blue")
|
||||
|
||||
# datsetio client to get the rows
|
||||
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
|
||||
response = await datasetio_client.get_rows_paginated(
|
||||
dataset_id="test-dataset",
|
||||
rows_in_page=4,
|
||||
page_token=None,
|
||||
filter_condition=None,
|
||||
)
|
||||
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||

llama_stack/apis/datasetio/datasetio.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_stack.apis.datasets import *  # noqa: F403


@json_schema_type
class PaginatedRowsResult(BaseModel):
    # the rows obey the DatasetSchema for the given dataset
    rows: List[Dict[str, Any]]
    total_count: int
    next_page_token: Optional[str] = None


class DatasetStore(Protocol):
    def get_dataset(self, dataset_id: str) -> Dataset: ...


@runtime_checkable
class DatasetIO(Protocol):
    # keeping for aligning with inference/safety, but this is not used
    dataset_store: DatasetStore

    @webmethod(route="/datasetio/get-rows-paginated", method="GET")
    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult: ...
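Note: a sketch of paging through a dataset with get_rows_paginated; `datasetio` is assumed to be any implementation of the DatasetIO protocol above, and `process` is a hypothetical per-row handler:

    # Sketch only: iterate over every row of a registered dataset, one page at a time.
    page_token = None
    while True:
        result = await datasetio.get_rows_paginated(
            dataset_id="test-dataset",
            rows_in_page=100,
            page_token=page_token,
        )
        for row in result.rows:
            process(row)  # hypothetical per-row handler
        if result.next_page_token is None:  # no more pages
            break
        page_token = result.next_page_token
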

llama_stack/apis/datasets/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .datasets import *  # noqa: F401 F403

llama_stack/apis/datasets/client.py (new file, 116 lines)
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from .datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
|
||||
|
||||
|
||||
class DatasetsClient(Datasets):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_def: DatasetDefWithProvider,
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/datasets/register",
|
||||
json={
|
||||
"dataset_def": json.loads(dataset_def.json()),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return
|
||||
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_identifier: str,
|
||||
) -> Optional[DatasetDefWithProvider]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/datasets/get",
|
||||
params={
|
||||
"dataset_identifier": dataset_identifier,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.json():
|
||||
return
|
||||
|
||||
return DatasetDefWithProvider(**response.json())
|
||||
|
||||
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/datasets/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.json():
|
||||
return
|
||||
|
||||
return [DatasetDefWithProvider(**x) for x in response.json()]
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = DatasetsClient(f"http://{host}:{port}")
|
||||
|
||||
# register dataset
|
||||
test_file = (
|
||||
Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
/ "providers/tests/datasetio/test_dataset.csv"
|
||||
)
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
response = await client.register_dataset(
|
||||
DatasetDefWithProvider(
|
||||
identifier="test-dataset",
|
||||
provider_id="meta0",
|
||||
url=URL(
|
||||
uri=test_url,
|
||||
),
|
||||
dataset_schema={
|
||||
"generated_answer": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# list datasets
|
||||
list_dataset = await client.list_datasets()
|
||||
cprint(list_dataset, "blue")
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||

llama_stack/apis/datasets/datasets.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Literal, Optional, Protocol

from llama_models.llama3.api.datatypes import URL

from llama_models.schema_utils import json_schema_type, webmethod

from pydantic import BaseModel, Field

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType


class CommonDatasetFields(BaseModel):
    dataset_schema: Dict[str, ParamType]
    url: URL
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this dataset",
    )


@json_schema_type
class Dataset(CommonDatasetFields, Resource):
    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value

    @property
    def dataset_id(self) -> str:
        return self.identifier

    @property
    def provider_dataset_id(self) -> str:
        return self.provider_resource_id


class DatasetInput(CommonDatasetFields, BaseModel):
    dataset_id: str
    provider_id: Optional[str] = None
    provider_dataset_id: Optional[str] = None


class Datasets(Protocol):
    @webmethod(route="/datasets/register", method="POST")
    async def register_dataset(
        self,
        dataset_id: str,
        dataset_schema: Dict[str, ParamType],
        url: URL,
        provider_dataset_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...

    @webmethod(route="/datasets/get", method="GET")
    async def get_dataset(
        self,
        dataset_id: str,
    ) -> Optional[Dataset]: ...

    @webmethod(route="/datasets/list", method="GET")
    async def list_datasets(self) -> List[Dataset]: ...
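Note: a sketch of registering and then fetching a dataset through the Datasets protocol above; `datasets` is assumed to be any implementation (for example a routing client), and the URL is a placeholder:

    # Sketch only: register a CSV-backed dataset, then look it up again.
    await datasets.register_dataset(
        dataset_id="test-dataset",
        dataset_schema={
            "input_query": StringType(),
            "expected_answer": StringType(),
            "generated_answer": StringType(),
        },
        url=URL(uri="https://example.com/test_dataset.csv"),  # placeholder URL
        provider_id="meta0",
    )
    dataset = await datasets.get_dataset(dataset_id="test-dataset")
    assert dataset is not None and dataset.dataset_id == "test-dataset"
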

@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .evals import *  # noqa: F401 F403
+from .eval import *  # noqa: F401 F403

llama_stack/apis/eval/eval.py (new file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Optional, Protocol, Union

from typing_extensions import Annotated

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.eval_tasks import *  # noqa: F403


@json_schema_type
class ModelCandidate(BaseModel):
    type: Literal["model"] = "model"
    model: str
    sampling_params: SamplingParams
    system_message: Optional[SystemMessage] = None


@json_schema_type
class AgentCandidate(BaseModel):
    type: Literal["agent"] = "agent"
    config: AgentConfig


EvalCandidate = Annotated[
    Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
]


@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
    type: Literal["benchmark"] = "benchmark"
    eval_candidate: EvalCandidate
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )


@json_schema_type
class AppEvalTaskConfig(BaseModel):
    type: Literal["app"] = "app"
    eval_candidate: EvalCandidate
    scoring_params: Dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run",
        default_factory=dict,
    )
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )
    # we could optinally add any specific dataset config here


EvalTaskConfig = Annotated[
    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
]


@json_schema_type
class EvaluateResponse(BaseModel):
    generations: List[Dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: Dict[str, ScoringResult]


class Eval(Protocol):
    @webmethod(route="/eval/run-eval", method="POST")
    async def run_eval(
        self,
        task_id: str,
        task_config: EvalTaskConfig,
    ) -> Job: ...

    @webmethod(route="/eval/evaluate-rows", method="POST")
    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse: ...

    @webmethod(route="/eval/job/status", method="GET")
    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

    @webmethod(route="/eval/job/cancel", method="POST")
    async def job_cancel(self, task_id: str, job_id: str) -> None: ...

    @webmethod(route="/eval/job/result", method="GET")
    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
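Note: a sketch of kicking off an eval run and polling the returned Job, assuming `eval_api` is any implementation of the Eval protocol above and "my-eval-task" was registered through the EvalTasks API; the model name and polling interval are illustrative:

    # Sketch only: run an eval task against a model candidate and poll until it finishes.
    job = await eval_api.run_eval(
        task_id="my-eval-task",
        task_config=AppEvalTaskConfig(
            eval_candidate=ModelCandidate(
                model="Llama3.2-3B-Instruct",
                sampling_params=SamplingParams(),
            ),
        ),
    )
    while await eval_api.job_status("my-eval-task", job.job_id) == JobStatus.in_progress:
        await asyncio.sleep(5)  # illustrative polling interval; requires `import asyncio`
    result = await eval_api.job_result("my-eval-task", job.job_id)
    print(result.scores)  # keyed by scoring function id
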

llama_stack/apis/eval_tasks/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .eval_tasks import *  # noqa: F401 F403

llama_stack/apis/eval_tasks/eval_tasks.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod

from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType


class CommonEvalTaskFields(BaseModel):
    dataset_id: str
    scoring_functions: List[str]
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata for this evaluation task",
    )


@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
    type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value

    @property
    def eval_task_id(self) -> str:
        return self.identifier

    @property
    def provider_eval_task_id(self) -> str:
        return self.provider_resource_id


class EvalTaskInput(CommonEvalTaskFields, BaseModel):
    eval_task_id: str
    provider_id: Optional[str] = None
    provider_eval_task_id: Optional[str] = None


@runtime_checkable
class EvalTasks(Protocol):
    @webmethod(route="/eval-tasks/list", method="GET")
    async def list_eval_tasks(self) -> List[EvalTask]: ...

    @webmethod(route="/eval-tasks/get", method="GET")
    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...

    @webmethod(route="/eval-tasks/register", method="POST")
    async def register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_eval_task_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...
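Note: a sketch of registering an eval task that ties a dataset to scoring functions; `eval_tasks` is assumed to be any implementation of the EvalTasks protocol above, and the scoring function id is illustrative:

    # Sketch only: register an eval task over the dataset registered earlier in this diff.
    await eval_tasks.register_eval_task(
        eval_task_id="my-eval-task",
        dataset_id="test-dataset",
        scoring_functions=["basic::equality"],  # illustrative scoring function id
    )
    tasks = await eval_tasks.list_eval_tasks()
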
@ -1,122 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Protocol
|
||||
|
||||
from llama_models.schema_utils import webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.dataset import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
||||
class TextGenerationMetric(Enum):
|
||||
perplexity = "perplexity"
|
||||
rouge = "rouge"
|
||||
bleu = "bleu"
|
||||
|
||||
|
||||
class QuestionAnsweringMetric(Enum):
|
||||
em = "em"
|
||||
f1 = "f1"
|
||||
|
||||
|
||||
class SummarizationMetric(Enum):
|
||||
rouge = "rouge"
|
||||
bleu = "bleu"
|
||||
|
||||
|
||||
class EvaluationJob(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
||||
class EvaluationJobLogStream(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
||||
class EvaluateTaskRequestCommon(BaseModel):
|
||||
job_uuid: str
|
||||
dataset: TrainEvalDataset
|
||||
|
||||
checkpoint: Checkpoint
|
||||
|
||||
# generation params
|
||||
sampling_params: SamplingParams = SamplingParams()
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate text generation."""
|
||||
|
||||
metrics: List[TextGenerationMetric]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate question answering."""
|
||||
|
||||
metrics: List[QuestionAnsweringMetric]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate summarization."""
|
||||
|
||||
metrics: List[SummarizationMetric]
|
||||
|
||||
|
||||
class EvaluationJobStatusResponse(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluationJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a evaluation job."""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
|
||||
class Evaluations(Protocol):
|
||||
@webmethod(route="/evaluate/text_generation/")
|
||||
def evaluate_text_generation(
|
||||
self,
|
||||
metrics: List[TextGenerationMetric],
|
||||
) -> EvaluationJob: ...
|
||||
|
||||
@webmethod(route="/evaluate/question_answering/")
|
||||
def evaluate_question_answering(
|
||||
self,
|
||||
metrics: List[QuestionAnsweringMetric],
|
||||
) -> EvaluationJob: ...
|
||||
|
||||
@webmethod(route="/evaluate/summarization/")
|
||||
def evaluate_summarization(
|
||||
self,
|
||||
metrics: List[SummarizationMetric],
|
||||
) -> EvaluationJob: ...
|
||||
|
||||
@webmethod(route="/evaluate/jobs")
|
||||
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/status")
|
||||
def get_evaluation_job_status(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobStatusResponse: ...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/evaluate/job/logs")
|
||||
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/cancel")
|
||||
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/artifacts")
|
||||
def get_evaluation_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobArtifactsResponse: ...
|
||||
|
|
@ -53,6 +53,7 @@ class InferenceClient(Inference):
|
|||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
@ -63,9 +64,33 @@ class InferenceClient(Inference):
|
|||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
return ChatCompletionResponse(**j)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
|
|
@ -77,7 +102,8 @@ class InferenceClient(Inference):
|
|||
if response.status_code != 200:
|
||||
content = await response.aread()
|
||||
cprint(
|
||||
f"Error: HTTP {response.status_code} {content.decode()}", "red"
|
||||
f"Error: HTTP {response.status_code} {content.decode()}",
|
||||
"red",
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -85,16 +111,11 @@ class InferenceClient(Inference):
|
|||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
if request.stream:
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
|
@ -120,7 +141,8 @@ async def run_main(
|
|||
else:
|
||||
logprobs_config = None
|
||||
|
||||
iterator = client.chat_completion(
|
||||
assert stream, "Non streaming not supported here"
|
||||
iterator = await client.chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
|
|
@ -150,7 +172,7 @@ async def run_mm_main(
|
|||
],
|
||||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
iterator = await client.chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,15 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, Union
|
||||
from typing import (
|
||||
AsyncIterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -14,6 +22,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
|
@ -24,6 +33,7 @@ class LogProbConfig(BaseModel):
|
|||
class QuantizationType(Enum):
|
||||
bf16 = "bf16"
|
||||
fp8 = "fp8"
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -36,8 +46,14 @@ class Bf16QuantizationConfig(BaseModel):
|
|||
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Int4QuantizationConfig(BaseModel):
|
||||
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
QuantizationConfig = Annotated[
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
|
@ -73,11 +89,35 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
stop_reason: Optional[StopReason] = None
|
||||
|
||||
|
||||
class ResponseFormatType(Enum):
|
||||
json_schema = "json_schema"
|
||||
grammar = "grammar"
|
||||
|
||||
|
||||
class JsonSchemaResponseFormat(BaseModel):
|
||||
type: Literal[ResponseFormatType.json_schema.value] = (
|
||||
ResponseFormatType.json_schema.value
|
||||
)
|
||||
json_schema: Dict[str, Any]
|
||||
|
||||
|
||||
class GrammarResponseFormat(BaseModel):
|
||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedTextMedia
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
|
@ -87,7 +127,8 @@ class CompletionRequest(BaseModel):
|
|||
class CompletionResponse(BaseModel):
|
||||
"""Completion response."""
|
||||
|
||||
completion_message: CompletionMessage
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
|
|
@ -105,6 +146,7 @@ class BatchCompletionRequest(BaseModel):
|
|||
model: str
|
||||
content_batch: List[InterleavedTextMedia]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
||||
|
|
@ -112,7 +154,7 @@ class BatchCompletionRequest(BaseModel):
|
|||
class BatchCompletionResponse(BaseModel):
|
||||
"""Batch completion response."""
|
||||
|
||||
completion_message_batch: List[CompletionMessage]
|
||||
batch: List[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -127,6 +169,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
|
@ -164,7 +207,7 @@ class BatchChatCompletionRequest(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
completion_message_batch: List[CompletionMessage]
|
||||
batch: List[ChatCompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@@ -172,34 +215,45 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]


 class ModelStore(Protocol):
     def get_model(self, identifier: str) -> Model: ...


 @runtime_checkable
 class Inference(Protocol):
     model_store: ModelStore

     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

-    @webmethod(route="/inference/chat_completion")
+    @webmethod(route="/inference/chat-completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]: ...

     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
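Note: chat_completion now returns either a ChatCompletionResponse or an async iterator of ChatCompletionResponseStreamChunk depending on `stream`. A caller-side sketch (illustrative only; `inference` is assumed to be any implementation of the Inference protocol above, and the chunk's `event.delta` field is assumed from the stream event types):

    # Sketch only: handle both the streaming and non-streaming shapes of chat_completion.
    response = await inference.chat_completion(
        model_id="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello")],
        stream=stream,  # caller-chosen bool
    )
    if isinstance(response, ChatCompletionResponse):  # stream=False
        print(response.completion_message.content)
    else:  # stream=True: AsyncIterator[ChatCompletionResponseStreamChunk]
        async for chunk in response:
            print(chunk.event.delta, end="")  # `delta` assumed; see event types above
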
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -12,15 +12,15 @@ from pydantic import BaseModel
|
|||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
description: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
providers: List[str]
|
||||
provider_types: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -29,6 +29,7 @@ class HealthInfo(BaseModel):
|
|||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/providers/list", method="GET")
|
||||
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -13,11 +12,11 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks.client import MemoryBanksClient
|
||||
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
||||
|
||||
|
||||
|
|
@ -35,45 +34,6 @@ class MemoryClient(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(
|
||||
f"{self.base_url}/memory/get",
|
||||
params={
|
||||
"bank_id": bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
f"{self.base_url}/memory/create",
|
||||
json={
|
||||
"name": name,
|
||||
"config": config.dict(),
|
||||
"url": url,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
if not d:
|
||||
return None
|
||||
return MemoryBank(**d)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
@ -113,23 +73,28 @@ class MemoryClient(Memory):
|
|||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryClient(f"http://{host}:{port}")
|
||||
banks_client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
# create a memory bank
|
||||
bank = await client.create_memory_bank(
|
||||
name="test_bank",
|
||||
config=VectorMemoryBankConfig(
|
||||
bank_id="test_bank",
|
||||
bank = VectorMemoryBank(
|
||||
identifier="test_bank",
|
||||
provider_id="",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
await banks_client.register_memory_bank(
|
||||
bank.identifier,
|
||||
VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
provider_resource_id=bank.identifier,
|
||||
)
|
||||
cprint(json.dumps(bank.dict(), indent=4), "green")
|
||||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
|
||||
assert retrieved_bank.embedding_model == "all-MiniLM-L6-v2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
|
@ -160,15 +125,17 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
for i, path in enumerate(files)
|
||||
]
|
||||
|
||||
client = MemoryClient(f"http://{host}:{port}")
|
||||
|
||||
# insert some documents
|
||||
await client.insert_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# query the documents
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
query=[
|
||||
"How do I use Lora?",
|
||||
],
|
||||
|
|
@ -178,7 +145,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
print(f"Chunk:\n========\n{chunk}\n========\n")
|
||||
|
||||
response = await client.query_documents(
|
||||
bank_id=bank.bank_id,
|
||||
bank_id=bank.identifier,
|
||||
query=[
|
||||
"Tell me more about llama3 and torchtune",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -26,44 +26,6 @@ class MemoryBankDocument(BaseModel):
|
|||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
keyvalue = "keyvalue"
|
||||
keyword = "keyword"
|
||||
graph = "graph"
|
||||
|
||||
|
||||
class VectorMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class KeyValueMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||
|
||||
|
||||
class KeywordMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||
|
||||
|
||||
class GraphMemoryBankConfig(BaseModel):
|
||||
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
MemoryBankConfig = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankConfig,
|
||||
KeyValueMemoryBankConfig,
|
||||
KeywordMemoryBankConfig,
|
||||
GraphMemoryBankConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
content: InterleavedTextMedia
|
||||
token_count: int
|
||||
|
|
@ -76,45 +38,13 @@ class QueryDocumentsResponse(BaseModel):
|
|||
scores: List[float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryAPI(Protocol):
|
||||
@webmethod(route="/query_documents")
|
||||
def query_documents(
|
||||
self,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBank(BaseModel):
|
||||
bank_id: str
|
||||
name: str
|
||||
config: MemoryBankConfig
|
||||
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
|
||||
url: Optional[URL] = None
|
||||
class MemoryBankStore(Protocol):
|
||||
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Memory(Protocol):
|
||||
@webmethod(route="/memory/create")
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory/get", method="GET")
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory/drop", method="DELETE")
|
||||
async def drop_memory_bank(
|
||||
self,
|
||||
bank_id: str,
|
||||
) -> str: ...
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
# this will just block now until documents are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
|
|
@ -126,13 +56,6 @@ class Memory(Protocol):
|
|||
ttl_seconds: Optional[int] = None,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/update")
|
||||
async def update_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/memory/query")
|
||||
async def query_documents(
|
||||
self,
|
||||
|
|
@ -140,17 +63,3 @@ class Memory(Protocol):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse: ...
|
||||
|
||||
@webmethod(route="/memory/documents/get", method="GET")
|
||||
async def get_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> List[MemoryBankDocument]: ...
|
||||
|
||||
@webmethod(route="/memory/documents/delete", method="DELETE")
|
||||
async def delete_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
document_ids: List[str],
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
|
@ -15,6 +15,27 @@ from termcolor import cprint
|
|||
from .memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
def deserialize_memory_bank_def(
|
||||
j: Optional[Dict[str, Any]]
|
||||
) -> MemoryBankDefWithProvider:
|
||||
if j is None:
|
||||
return None
|
||||
|
||||
if "type" not in j:
|
||||
raise ValueError("Memory bank type not specified")
|
||||
type = j["type"]
|
||||
if type == MemoryBankType.vector.value:
|
||||
return VectorMemoryBank(**j)
|
||||
elif type == MemoryBankType.keyvalue.value:
|
||||
return KeyValueMemoryBank(**j)
|
||||
elif type == MemoryBankType.keyword.value:
|
||||
return KeywordMemoryBank(**j)
|
||||
elif type == MemoryBankType.graph.value:
|
||||
return GraphMemoryBank(**j)
|
||||
else:
|
||||
raise ValueError(f"Unknown memory bank type: {type}")
|
||||
|
||||
|
||||
class MemoryBanksClient(MemoryBanks):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
|
@ -25,37 +46,71 @@ class MemoryBanksClient(MemoryBanks):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [MemoryBankSpec(**x) for x in response.json()]
|
||||
return [deserialize_memory_bank_def(x) for x in response.json()]
|
||||
|
||||
async def get_serving_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]:
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_resource_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/memory_banks/register",
|
||||
json={
|
||||
"memory_bank_id": memory_bank_id,
|
||||
"provider_resource_id": provider_resource_id,
|
||||
"provider_id": provider_id,
|
||||
"params": params.dict(),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
) -> Optional[MemoryBank]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
params={
|
||||
"bank_type": bank_type.value,
|
||||
"memory_bank_id": memory_bank_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return MemoryBankSpec(**j)
|
||||
return deserialize_memory_bank_def(j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_available_memory_banks()
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
# register memory bank for the first time
|
||||
response = await client.register_memory_bank(
|
||||
memory_bank_id="test_bank2",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
cprint(f"register_memory_bank response={response}", "blue")
|
||||
|
||||
# list again after registering
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,29 +4,146 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.memory import MemoryBankType
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankSpec(BaseModel):
|
||||
bank_type: MemoryBankType
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
class MemoryBankType(Enum):
|
||||
vector = "vector"
|
||||
keyvalue = "keyvalue"
|
||||
keyword = "keyword"
|
||||
graph = "graph"
|
||||
|
||||
|
||||
# define params for each type of memory bank, this leads to a tagged union
|
||||
# accepted as input from the API or from the config.
|
||||
@json_schema_type
|
||||
class VectorMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeyValueMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
||||
MemoryBankType.keyvalue.value
|
||||
)
|
||||
|
||||
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ...
|
||||
@json_schema_type
|
||||
class KeywordMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
||||
MemoryBankType.keyword.value
|
||||
)
|
||||
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_serving_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]: ...
|
||||
|
||||
@json_schema_type
|
||||
class GraphMemoryBankParams(BaseModel):
|
||||
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
BankParams = Annotated[
|
||||
Union[
|
||||
VectorMemoryBankParams,
|
||||
KeyValueMemoryBankParams,
|
||||
KeywordMemoryBankParams,
|
||||
GraphMemoryBankParams,
|
||||
],
|
||||
Field(discriminator="memory_bank_type"),
|
||||
]
|
||||
|
||||
|
||||
# Some common functionality for memory banks.
|
||||
class MemoryBankResourceMixin(Resource):
|
||||
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
|
||||
|
||||
@property
|
||||
def memory_bank_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_memory_bank_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VectorMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class KeyValueMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
|
||||
MemoryBankType.keyvalue.value
|
||||
)
|
||||
|
||||
|
||||
# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
|
||||
@json_schema_type
|
||||
class KeywordMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
|
||||
MemoryBankType.keyword.value
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GraphMemoryBank(MemoryBankResourceMixin):
|
||||
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||
|
||||
|
||||
MemoryBank = Annotated[
|
||||
Union[
|
||||
VectorMemoryBank,
|
||||
KeyValueMemoryBank,
|
||||
KeywordMemoryBank,
|
||||
GraphMemoryBank,
|
||||
],
|
||||
Field(discriminator="memory_bank_type"),
|
||||
]
|
||||
|
||||
|
||||
class MemoryBankInput(BaseModel):
|
||||
memory_bank_id: str
|
||||
params: BankParams
|
||||
provider_memory_bank_id: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory-banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory-banks/get", method="GET")
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...
|
||||
|
||||
@webmethod(route="/memory-banks/register", method="POST")
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory-banks/unregister", method="POST")
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
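For reference, a minimal sketch of driving this registration surface through the MemoryBanksClient shown earlier in this diff. The base URL, bank id, and the exact import paths are assumptions for illustration; the call shapes mirror the register_memory_bank/get_memory_bank signatures above.

# Illustrative only: exercises the MemoryBanks API via the HTTP client from this
# diff. "http://localhost:5000" and "docs_bank" are placeholder values, and the
# import paths are assumed from the llama_stack.apis.<api>.client convention.
import asyncio

from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.memory_banks.client import MemoryBanksClient


async def demo() -> None:
    client = MemoryBanksClient("http://localhost:5000")

    # params is the tagged union defined above, discriminated on memory_bank_type
    await client.register_memory_bank(
        memory_bank_id="docs_bank",
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )

    # fetch the registered bank back by id
    print(await client.get_memory_bank("docs_bank"))


asyncio.run(demo())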
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
|
@ -25,21 +26,32 @@ class ModelsClient(Models):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
async def list_models(self) -> List[Model]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ModelServingSpec(**x) for x in response.json()]
|
||||
return [Model(**x) for x in response.json()]
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
async def register_model(self, model: Model) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/models/register",
|
||||
json={
|
||||
"model": json.loads(model.model_dump_json()),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[Model]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models/get",
|
||||
params={
|
||||
"core_model_id": core_model_id,
|
||||
"identifier": identifier,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
|
@ -47,7 +59,16 @@ class ModelsClient(Models):
|
|||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return ModelServingSpec(**j)
|
||||
return Model(**j)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{self.base_url}/models/delete",
|
||||
params={"model_id": model_id},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
|
|
|||
|
|
@ -4,29 +4,60 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this model",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelServingSpec(BaseModel):
|
||||
llama_model: Model = Field(
|
||||
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
|
||||
)
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
)
|
||||
class Model(CommonModelFields, Resource):
|
||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_model_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ModelInput(CommonModelFields):
|
||||
model_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_model_id: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[ModelServingSpec]: ...
|
||||
async def list_models(self) -> List[Model]: ...
|
||||
|
||||
@webmethod(route="/models/get", method="GET")
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
|
||||
async def get_model(self, identifier: str) -> Optional[Model]: ...
|
||||
|
||||
@webmethod(route="/models/register", method="POST")
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models/unregister", method="POST")
|
||||
async def unregister_model(self, model_id: str) -> None: ...
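A small sketch, under assumed import paths and placeholder identifiers, of registering a model resource with the ModelsClient from this diff and listing what the stack is serving.

# Illustrative only: the model identifier, provider id, and base URL below are
# placeholders; register_model/list_models follow the client shown earlier.
import asyncio

from llama_stack.apis.models import Model
from llama_stack.apis.models.client import ModelsClient


async def demo() -> None:
    client = ModelsClient("http://localhost:5000")

    # Model is a Resource, so identifier / provider_resource_id / provider_id
    # are required; metadata is free-form.
    await client.register_model(
        Model(
            identifier="Llama3.1-8B-Instruct",
            provider_resource_id="Llama3.1-8B-Instruct",
            provider_id="inference0",
            metadata={},
        )
    )

    for model in await client.list_models():
        print(model.model_id, model.provider_id)


asyncio.run(demo())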
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.dataset import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
||||
|
|
@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
|
|||
job_uuid: str
|
||||
|
||||
model: str
|
||||
dataset: TrainEvalDataset
|
||||
validation_dataset: TrainEvalDataset
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: FinetuningAlgorithm
|
||||
algorithm_config: Union[
|
||||
|
|
@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
|
|||
|
||||
finetuned_model: URL
|
||||
|
||||
dataset: TrainEvalDataset
|
||||
validation_dataset: TrainEvalDataset
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: RLHFAlgorithm
|
||||
algorithm_config: Union[DPOAlignmentConfig]
|
||||
|
|
@ -176,13 +176,13 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
|||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post_training/supervised_fine_tune")
|
||||
@webmethod(route="/post-training/supervised-fine-tune")
|
||||
def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset: TrainEvalDataset,
|
||||
validation_dataset: TrainEvalDataset,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: Union[
|
||||
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
||||
|
|
@ -193,13 +193,13 @@ class PostTraining(Protocol):
|
|||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post_training/preference_optimize")
|
||||
@webmethod(route="/post-training/preference-optimize")
|
||||
def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: URL,
|
||||
dataset: TrainEvalDataset,
|
||||
validation_dataset: TrainEvalDataset,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: RLHFAlgorithm,
|
||||
algorithm_config: Union[DPOAlignmentConfig],
|
||||
optimizer_config: OptimizerConfig,
|
||||
|
|
@ -208,22 +208,22 @@ class PostTraining(Protocol):
|
|||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post_training/jobs")
|
||||
@webmethod(route="/post-training/jobs")
|
||||
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/post_training/job/logs")
|
||||
@webmethod(route="/post-training/job/logs")
|
||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
||||
|
||||
@webmethod(route="/post_training/job/status")
|
||||
@webmethod(route="/post-training/job/status")
|
||||
def get_training_job_status(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobStatusResponse: ...
|
||||
|
||||
@webmethod(route="/post_training/job/cancel")
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post_training/job/artifacts")
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
def get_training_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobArtifactsResponse: ...
|
||||
|
|
|
|||
39
llama_stack/apis/resource.py
Normal file
|
|
@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class ResourceType(Enum):
    model = "model"
    shield = "shield"
    memory_bank = "memory_bank"
    dataset = "dataset"
    scoring_function = "scoring_function"
    eval_task = "eval_task"


class Resource(BaseModel):
    """Base class for all Llama Stack resources"""

    identifier: str = Field(
        description="Unique identifier for this resource in llama stack"
    )

    provider_resource_id: str = Field(
        description="Unique identifier for this resource in the provider",
        default=None,
    )

    provider_id: str = Field(description="ID of the provider that owns this resource")

    type: ResourceType = Field(
        description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)"
    )
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoredMessage(BaseModel):
|
||||
message: Message
|
||||
score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogGenerations(BaseModel):
|
||||
dialog: List[Message]
|
||||
sampled_generations: List[Message]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoredDialogGenerations(BaseModel):
|
||||
dialog: List[Message]
|
||||
scored_generations: List[ScoredMessage]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RewardScoringRequest(BaseModel):
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
dialog_generations: List[DialogGenerations]
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RewardScoringResponse(BaseModel):
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
scored_generations: List[ScoredDialogGenerations]
|
||||
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
@webmethod(route="/reward_scoring/score")
|
||||
def reward_score(
|
||||
self,
|
||||
dialog_generations: List[DialogGenerations],
|
||||
model: str,
|
||||
) -> Union[RewardScoringResponse]: ...
|
||||
|
|
@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
|
|||
|
||||
|
||||
def encodable_dict(d: BaseModel):
|
||||
return json.loads(d.json())
|
||||
return json.loads(d.model_dump_json())
|
||||
|
||||
|
||||
class SafetyClient(Safety):
|
||||
|
|
@ -41,13 +41,13 @@ class SafetyClient(Safety):
|
|||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message]
|
||||
self, shield_id: str, messages: List[Message]
|
||||
) -> RunShieldResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/safety/run_shield",
|
||||
json=dict(
|
||||
shield_type=shield_type,
|
||||
shield_id=shield_id,
|
||||
messages=[encodable_dict(m) for m in messages],
|
||||
),
|
||||
headers={
|
||||
|
|
@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
|
|||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
response = await client.run_shield(
|
||||
shield_type="llama_guard",
|
||||
shield_id="Llama-Guard-3-1B",
|
||||
messages=[message],
|
||||
)
|
||||
print(response)
|
||||
|
|
@ -91,13 +91,7 @@ async def run_main(host: str, port: int, image_path: str = None):
|
|||
]:
|
||||
cprint(f"User>{message.content}", "green")
|
||||
response = await client.run_shield(
|
||||
shield_type="llama_guard",
|
||||
messages=[message],
|
||||
)
|
||||
print(response)
|
||||
|
||||
response = await client.run_shield(
|
||||
shield_type="injection_shield",
|
||||
shield_id="llama_guard",
|
||||
messages=[message],
|
||||
)
|
||||
print(response)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Protocol
from typing import Any, Dict, List, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403


@json_schema_type

@ -37,8 +38,18 @@ class RunShieldResponse(BaseModel):
    violation: Optional[SafetyViolation] = None


class ShieldStore(Protocol):
    async def get_shield(self, identifier: str) -> Shield: ...


@runtime_checkable
class Safety(Protocol):
    @webmethod(route="/safety/run_shield")
    shield_store: ShieldStore

    @webmethod(route="/safety/run-shield")
    async def run_shield(
        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
        self,
        shield_id: str,
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse: ...
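A minimal sketch of calling the updated signature (shield_id instead of shield_type) through the SafetyClient shown earlier in this diff; the base URL is a placeholder and UserMessage is assumed to come from the star-imported llama_models datatypes.

# Illustrative only: screens one user message with the Llama Guard shield used
# in the safety client example above.
import asyncio

from llama_models.llama3.api.datatypes import UserMessage
from llama_stack.apis.safety.client import SafetyClient


async def demo() -> None:
    client = SafetyClient("http://localhost:5000")
    response = await client.run_shield(
        shield_id="Llama-Guard-3-1B",
        messages=[UserMessage(content="<user message to screen>")],
    )
    # response.violation is populated when the shield flags the content
    print(response.violation)


asyncio.run(demo())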
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .dataset import *  # noqa: F401 F403
from .scoring import *  # noqa: F401 F403
|
||||
132
llama_stack/apis/scoring/client.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio.client import DatasetIOClient
|
||||
from llama_stack.apis.datasets.client import DatasetsClient
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
|
||||
|
||||
|
||||
class ScoringClient(Scoring):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
self, dataset_id: str, scoring_functions: List[str]
|
||||
) -> ScoreBatchResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/scoring/score_batch",
|
||||
json={
|
||||
"dataset_id": dataset_id,
|
||||
"scoring_functions": scoring_functions,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.json():
|
||||
return
|
||||
|
||||
return ScoreBatchResponse(**response.json())
|
||||
|
||||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/scoring/score",
|
||||
json={
|
||||
"input_rows": input_rows,
|
||||
"scoring_functions": scoring_functions,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.json():
|
||||
return
|
||||
|
||||
return ScoreResponse(**response.json())
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = DatasetsClient(f"http://{host}:{port}")
|
||||
|
||||
# register dataset
|
||||
test_file = (
|
||||
Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
/ "providers/tests/datasetio/test_dataset.csv"
|
||||
)
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
response = await client.register_dataset(
|
||||
DatasetDefWithProvider(
|
||||
identifier="test-dataset",
|
||||
provider_id="meta0",
|
||||
url=URL(
|
||||
uri=test_url,
|
||||
),
|
||||
dataset_schema={
|
||||
"generated_answer": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# list datasets
|
||||
list_dataset = await client.list_datasets()
|
||||
cprint(list_dataset, "blue")
|
||||
|
||||
# datsetio client to get the rows
|
||||
datasetio_client = DatasetIOClient(f"http://{host}:{port}")
|
||||
response = await datasetio_client.get_rows_paginated(
|
||||
dataset_id="test-dataset",
|
||||
rows_in_page=4,
|
||||
page_token=None,
|
||||
filter_condition=None,
|
||||
)
|
||||
cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
|
||||
|
||||
# scoring client to score the rows
|
||||
scoring_client = ScoringClient(f"http://{host}:{port}")
|
||||
response = await scoring_client.score(
|
||||
input_rows=response.rows,
|
||||
scoring_functions=["equality"],
|
||||
)
|
||||
cprint(f"score response={response}", "blue")
|
||||
|
||||
# test scoring batch using datasetio api
|
||||
scoring_client = ScoringClient(f"http://{host}:{port}")
|
||||
response = await scoring_client.score_batch(
|
||||
dataset_id="test-dataset",
|
||||
scoring_functions=["equality"],
|
||||
)
|
||||
cprint(f"score_batch response={response}", "cyan")
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
60
llama_stack/apis/scoring/scoring.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
|
||||
|
||||
# mapping of metric to value
|
||||
ScoringResultRow = Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringResult(BaseModel):
|
||||
score_rows: List[ScoringResultRow]
|
||||
# aggregated metrics to value
|
||||
aggregated_results: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoreBatchResponse(BaseModel):
|
||||
dataset_id: Optional[str] = None
|
||||
results: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoreResponse(BaseModel):
|
||||
# each key in the dict is a scoring function name
|
||||
results: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class ScoringFunctionStore(Protocol):
|
||||
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Scoring(Protocol):
|
||||
scoring_function_store: ScoringFunctionStore
|
||||
|
||||
@webmethod(route="/scoring/score-batch")
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
|
||||
@webmethod(route="/scoring/score")
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse: ...
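A sketch of the call shape a Scoring implementation receives under this protocol. "equality" matches the scoring function exercised by the client example above; the "llm-as-judge::base" id, the judge model, and the import paths are assumptions for illustration only.

# Illustrative only: scoring_functions maps a function id to optional per-call
# parameter overrides (None means "use the registered defaults").
from typing import Any, Dict, List, Optional

from llama_stack.apis.scoring import ScoreResponse, Scoring
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFnParams


async def score_rows(impl: Scoring) -> ScoreResponse:
    rows: List[Dict[str, Any]] = [
        {"input_query": "What is 2 + 2?", "generated_answer": "4", "expected_answer": "4"}
    ]
    fns: Dict[str, Optional[ScoringFnParams]] = {
        "equality": None,  # registered defaults
        "llm-as-judge::base": LLMAsJudgeScoringFnParams(  # hypothetical per-call override
            judge_model="Llama3.1-8B-Instruct",
            judge_score_regexes=[r"Score: (\d+)"],
        ),
    }
    return await impl.score(input_rows=rows, scoring_functions=fns)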
|
||||
7
llama_stack/apis/scoring_functions/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .scoring_functions import *  # noqa: F401 F403
|
||||
122
llama_stack/apis/scoring_functions/scoring_functions.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(Enum):
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
|
||||
ScoringFnParamsType.llm_as_judge.value
|
||||
)
|
||||
judge_model: str
|
||||
prompt_template: Optional[str] = None
|
||||
judge_score_regexes: Optional[List[str]] = Field(
|
||||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.regex_parser.value] = (
|
||||
ScoringFnParamsType.regex_parser.value
|
||||
)
|
||||
parsing_regexes: Optional[List[str]] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
description: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this definition",
|
||||
)
|
||||
return_type: ParamType = Field(
|
||||
description="The return type of the deterministic function",
|
||||
)
|
||||
params: Optional[ScoringFnParams] = Field(
|
||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
type: Literal[ResourceType.scoring_function.value] = (
|
||||
ResourceType.scoring_function.value
|
||||
)
|
||||
|
||||
@property
|
||||
def scoring_fn_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_scoring_fn_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||
scoring_fn_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_scoring_fn_id: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions/list", method="GET")
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/get", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/register", method="POST")
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
) -> None: ...
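A short sketch, under stated assumptions, of registering a regex-parsed scoring function against this protocol. The function id is a made-up placeholder, and NumberType is assumed to be one of the ParamType variants in llama_stack.apis.common.type_system alongside the StringType used by the scoring client above.

# Illustrative only: registers a scoring function whose answers are extracted
# with a regex, mirroring the RegexParserScoringFnParams defined in this file.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    RegexParserScoringFnParams,
    ScoringFunctions,
)


async def register_accuracy_fn(impl: ScoringFunctions) -> None:
    await impl.register_scoring_function(
        scoring_fn_id="my-eval::answer-accuracy",  # hypothetical id
        description="Extracts the final answer and scores it for equality",
        return_type=NumberType(),
        params=RegexParserScoringFnParams(
            parsing_regexes=[r"Answer:\s*(.*)$"],
        ),
    )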
|
||||
|
|
@ -25,21 +25,41 @@ class ShieldsClient(Shields):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
async def list_shields(self) -> List[Shield]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ShieldSpec(**x) for x in response.json()]
|
||||
return [Shield(**x) for x in response.json()]
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str],
|
||||
provider_id: Optional[str],
|
||||
params: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/shields/register",
|
||||
json={
|
||||
"shield_id": shield_id,
|
||||
"provider_shield_id": provider_shield_id,
|
||||
"provider_id": provider_id,
|
||||
"params": params,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def get_shield(self, shield_id: str) -> Optional[Shield]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/get",
|
||||
params={
|
||||
"shield_type": shield_type,
|
||||
"shield_id": shield_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
|
@ -49,7 +69,7 @@ class ShieldsClient(Shields):
|
|||
if j is None:
|
||||
return None
|
||||
|
||||
return ShieldSpec(**j)
|
||||
return Shield(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
|
|
|||
|
|
@ -4,25 +4,52 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldSpec(BaseModel):
|
||||
shield_type: str
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_type, and corresponding config. ",
|
||||
)
|
||||
class Shield(CommonShieldFields, Resource):
|
||||
"""A safety shield resource that can be used to check content"""
|
||||
|
||||
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
|
||||
|
||||
@property
|
||||
def shield_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_shield_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class ShieldInput(CommonShieldFields):
|
||||
shield_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_shield_id: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[ShieldSpec]: ...
|
||||
async def list_shields(self) -> List[Shield]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield: ...
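A minimal sketch of registering and then looking up a shield through the ShieldsClient shown earlier in this diff; the base URL is a placeholder and the shield id reuses the value from the safety client example.

# Illustrative only: register_shield/list_shields/get_shield follow the client
# signatures above; provider ids are left to the distribution's defaults.
import asyncio

from llama_stack.apis.shields.client import ShieldsClient


async def demo() -> None:
    client = ShieldsClient("http://localhost:5000")
    await client.register_shield(
        shield_id="Llama-Guard-3-1B",
        provider_shield_id=None,
        provider_id=None,
        params={},
    )
    print(await client.list_shields())
    print(await client.get_shield("Llama-Guard-3-1B"))


asyncio.run(demo())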
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.reward_scoring import * # noqa: F403
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
|
|
@ -40,12 +39,12 @@ class SyntheticDataGenerationRequest(BaseModel):
|
|||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
synthetic_data: List[ScoredDialogGenerations]
|
||||
synthetic_data: List[Dict[str, Any]]
|
||||
statistics: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic_data_generation/generate")
|
||||
@webmethod(route="/synthetic-data-generation/generate")
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: List[Message],
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@

from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, Union
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

@ -123,9 +123,10 @@ Event = Annotated[
]


@runtime_checkable
class Telemetry(Protocol):
    @webmethod(route="/telemetry/log_event")
    @webmethod(route="/telemetry/log-event")
    async def log_event(self, event: Event) -> None: ...

    @webmethod(route="/telemetry/get_trace", method="GET")
    @webmethod(route="/telemetry/get-trace", method="GET")
    async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
|
|
|
|||
7
llama_stack/apis/version.py
Normal file
|
|
@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_STACK_API_VERSION = "alpha"
|
||||
|
|
@ -9,15 +9,27 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import LlamaDownloadInfo
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
|
@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
|||
required=False,
|
||||
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
required=False,
|
||||
default=3,
|
||||
help="Maximum number of concurrent downloads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-patterns",
|
||||
type=str,
|
||||
|
|
@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
|
|||
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadTask:
|
||||
url: str
|
||||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
task_id: Optional[int] = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CustomTransferSpeedColumn(TransferSpeedColumn):
|
||||
def render(self, task):
|
||||
if task.finished:
|
||||
return "-"
|
||||
return super().render(task)
|
||||
|
||||
|
||||
class ParallelDownloader:
|
||||
def __init__(
|
||||
self,
|
||||
max_concurrent_downloads: int = 3,
|
||||
buffer_size: int = 1024 * 1024,
|
||||
timeout: int = 30,
|
||||
):
|
||||
self.max_concurrent_downloads = max_concurrent_downloads
|
||||
self.buffer_size = buffer_size
|
||||
self.timeout = timeout
|
||||
self.console = Console()
|
||||
self.progress = Progress(
|
||||
TextColumn("[bold blue]{task.description}"),
|
||||
BarColumn(bar_width=40),
|
||||
"[progress.percentage]{task.percentage:>3.1f}%",
|
||||
DownloadColumn(),
|
||||
CustomTransferSpeedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
expand=True,
|
||||
)
|
||||
self.client_options = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
async def retry_with_exponential_backoff(
|
||||
self, task: DownloadTask, func, *args, **kwargs
|
||||
):
|
||||
last_exception = None
|
||||
for attempt in range(task.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < task.max_retries - 1:
|
||||
wait_time = min(30, 2**attempt) # Cap at 30 seconds
|
||||
self.console.print(
|
||||
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
|
||||
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
raise last_exception
|
||||
|
||||
async def get_file_info(
|
||||
self, client: httpx.AsyncClient, task: DownloadTask
|
||||
) -> None:
|
||||
async def _get_info():
|
||||
response = await client.head(
|
||||
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
try:
|
||||
response = await self.retry_with_exponential_backoff(task, _get_info)
|
||||
|
||||
task.url = str(response.url)
|
||||
task.total_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
if task.total_size == 0:
|
||||
raise DownloadError(
|
||||
f"Unable to determine file size for {task.output_file}. "
|
||||
"The server might not support range requests."
|
||||
)
|
||||
|
||||
# Update the progress bar's total size once we know it
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
|
||||
raise
|
||||
|
||||
def verify_file_integrity(self, task: DownloadTask) -> bool:
|
||||
if not os.path.exists(task.output_file):
|
||||
return False
|
||||
return os.path.getsize(task.output_file) == task.total_size
|
||||
|
||||
async def download_chunk(
|
||||
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
|
||||
) -> None:
|
||||
async def _download_chunk():
|
||||
headers = {"Range": f"bytes={start}-{end}"}
|
||||
async with client.stream(
|
||||
"GET", task.url, headers=headers, **self.client_options
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
with open(task.output_file, "ab") as file:
|
||||
file.seek(start)
|
||||
async for chunk in response.aiter_bytes(self.buffer_size):
|
||||
file.write(chunk)
|
||||
task.downloaded_size += len(chunk)
|
||||
self.progress.update(
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.retry_with_exponential_backoff(task, _download_chunk)
|
||||
except Exception as e:
|
||||
raise DownloadError(
|
||||
f"Failed to download chunk {start}-{end} after "
|
||||
f"{task.max_retries} attempts: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def prepare_download(self, task: DownloadTask) -> None:
|
||||
output_dir = os.path.dirname(task.output_file)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if os.path.exists(task.output_file):
|
||||
task.downloaded_size = os.path.getsize(task.output_file)
|
||||
|
||||
async def download_file(self, task: DownloadTask) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(**self.client_options) as client:
|
||||
await self.get_file_info(client, task)
|
||||
|
||||
# Check if file is already downloaded
|
||||
if os.path.exists(task.output_file):
|
||||
if self.verify_file_integrity(task):
|
||||
self.console.print(
|
||||
f"[green]Already downloaded {task.output_file}[/green]"
|
||||
)
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
return
|
||||
|
||||
await self.prepare_download(task)
|
||||
|
||||
try:
|
||||
# Split the remaining download into chunks
|
||||
chunk_size = 27_000_000_000 # Cloudfront max chunk size
|
||||
chunks = []
|
||||
|
||||
current_pos = task.downloaded_size
|
||||
while current_pos < task.total_size:
|
||||
chunk_end = min(
|
||||
current_pos + chunk_size - 1, task.total_size - 1
|
||||
)
|
||||
chunks.append((current_pos, chunk_end))
|
||||
current_pos = chunk_end + 1
|
||||
|
||||
# Download chunks in sequence
|
||||
for chunk_start, chunk_end in chunks:
|
||||
await self.download_chunk(client, task, chunk_start, chunk_end)
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
self.progress.update(
|
||||
task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
|
||||
)
|
||||
raise DownloadError(
|
||||
f"Download failed for {task.output_file}: {str(e)}"
|
||||
) from e
|
||||
|
||||
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
||||
try:
|
||||
total_remaining_size = sum(
|
||||
task.total_size - task.downloaded_size for task in tasks
|
||||
)
|
||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||
free_space = shutil.disk_usage(dir_path).free
|
||||
|
||||
# Add 10% buffer for safety
|
||||
required_space = int(total_remaining_size * 1.1)
|
||||
|
||||
if free_space < required_space:
|
||||
self.console.print(
|
||||
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
|
||||
f"Available: {free_space // (1024 * 1024)} MB[/red]"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
||||
|
||||
async def download_all(self, tasks: List[DownloadTask]) -> None:
|
||||
if not tasks:
|
||||
raise ValueError("No download tasks provided")
|
||||
|
||||
if not self.has_disk_space(tasks):
|
||||
raise DownloadError("Insufficient disk space for downloads")
|
||||
|
||||
failed_tasks = []
|
||||
|
||||
with self.progress:
|
||||
for task in tasks:
|
||||
desc = f"Downloading {Path(task.output_file).name}"
|
||||
task.task_id = self.progress.add_task(
|
||||
desc, total=task.total_size, completed=task.downloaded_size
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
|
||||
|
||||
async def download_with_semaphore(task: DownloadTask):
|
||||
async with semaphore:
|
||||
try:
|
||||
await self.download_file(task)
|
||||
except Exception as e:
|
||||
failed_tasks.append((task, str(e)))
|
||||
|
||||
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
|
||||
|
||||
if failed_tasks:
|
||||
self.console.print("\n[red]Some downloads failed:[/red]")
|
||||
for task, error in failed_tasks:
|
||||
self.console.print(
|
||||
f"[red]- {Path(task.output_file).name}: {error}[/red]"
|
||||
)
|
||||
raise DownloadError(f"{len(failed_tasks)} downloads failed")
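A sketch of the chunk-splitting rule download_file() applies above, pulled out as a standalone helper for clarity. The constant mirrors the ~27 GB CloudFront per-request cap noted in the code; the example sizes are arbitrary.

# Illustrative only: ranges are inclusive byte offsets, starting from whatever
# is already on disk so a partially downloaded file resumes where it left off.
from typing import List, Tuple

CLOUDFRONT_MAX_CHUNK = 27_000_000_000


def plan_chunks(total_size: int, downloaded_size: int = 0) -> List[Tuple[int, int]]:
    chunks = []
    current_pos = downloaded_size
    while current_pos < total_size:
        chunk_end = min(current_pos + CLOUDFRONT_MAX_CHUNK - 1, total_size - 1)
        chunks.append((current_pos, chunk_end))
        current_pos = chunk_end + 1
    return chunks


# e.g. a 60 GB file with 10 GB already on disk resumes as two ranges:
# [(10_000_000_000, 36_999_999_999), (37_000_000_000, 59_999_999_999)]
print(plan_chunks(60_000_000_000, 10_000_000_000))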
|
||||
|
||||
|
||||
def _hf_download(
|
||||
model: "Model",
|
||||
hf_token: str,
|
||||
|
|
@ -120,67 +378,50 @@ def _hf_download(
|
|||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
||||
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"):
|
||||
def _meta_download(
|
||||
model: "Model",
|
||||
model_id: str,
|
||||
meta_url: str,
|
||||
info: "LlamaDownloadInfo",
|
||||
max_concurrent_downloads: int,
|
||||
):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model.descriptor()))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# I believe we can use some concurrency here if needed but not sure it is worth it
|
||||
# Create download tasks for each file
|
||||
tasks = []
|
||||
for f in info.files:
|
||||
output_file = str(output_dir / f)
|
||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||
total_size = info.pth_size if "consolidated" in f else 0
|
||||
cprint(f"Downloading `{f}`...", "white")
|
||||
downloader = ResumableDownloader(url, output_file, total_size)
|
||||
asyncio.run(downloader.download())
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {output_dir}")
|
||||
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
|
||||
|
||||
|
||||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_models.sku_list import llama_meta_net_info, resolve_model
|
||||
|
||||
from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
|
||||
|
||||
if args.manifest_file:
|
||||
_download_from_manifest(args.manifest_file)
|
||||
return
|
||||
|
||||
if args.model_id is None:
|
||||
parser.error("Please provide a model id")
|
||||
return
|
||||
|
||||
prompt_guard = prompt_guard_model_sku()
|
||||
if args.model_id == prompt_guard.model_id:
|
||||
model = prompt_guard
|
||||
info = prompt_guard_download_info()
|
||||
else:
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {args.model_id} not found")
|
||||
return
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url
|
||||
if not meta_url:
|
||||
meta_url = input(
|
||||
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
tasks.append(
|
||||
DownloadTask(
|
||||
url=url, output_file=output_file, total_size=total_size, max_retries=3
|
||||
)
|
||||
assert meta_url is not None and "llamameta.net" in meta_url
|
||||
_meta_download(model, meta_url, info)
|
||||
)
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
|
||||
cprint(
|
||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||
"white",
|
||||
)
|
||||
cprint(
|
||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||
"yellow",
|
||||
)
|
||||
|
||||
|
||||
class ModelEntry(BaseModel):
|
||||
model_id: str
|
||||
files: Dict[str, str]
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
|
|
@ -188,7 +429,7 @@ class Manifest(BaseModel):
|
|||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str):
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file, "r") as f:
|
||||
|
|
@ -198,143 +439,88 @@ def _download_from_manifest(manifest_file: str):
|
|||
if datetime.now() > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
for entry in manifest.models:
|
||||
print(f"Downloading model {entry.model_id}...")
|
||||
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
|
||||
output_dir = Path(model_local_dir(entry.model_id))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if any(output_dir.iterdir()):
|
||||
cprint(f"Output directory {output_dir} is not empty.", "red")
|
||||
console.print(
|
||||
f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
|
||||
)
|
||||
|
||||
while True:
|
||||
resp = input(
|
||||
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
|
||||
)
|
||||
if resp.lower() == "restart" or resp.lower() == "r":
|
||||
if resp.lower() in ["restart", "r"]:
|
||||
shutil.rmtree(output_dir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
break
|
||||
elif resp.lower() == "continue" or resp.lower() == "c":
|
            elif resp.lower() in ["continue", "c"]:
-               print("Continuing download...")
+               console.print("[blue]Continuing download...[/blue]")
                break
            else:
-               cprint("Invalid response. Please try again.", "red")
+               console.print("[red]Invalid response. Please try again.[/red]")

-   for fname, url in entry.files.items():
-       output_file = str(output_dir / fname)
-       downloader = ResumableDownloader(url, output_file)
-       asyncio.run(downloader.download())
+   # Create download tasks for all files in the manifest
+   tasks = [
+       DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
+       for fname, url in entry.files.items()
+   ]
+
+   # Initialize and run parallel downloader
+   downloader = ParallelDownloader(
+       max_concurrent_downloads=max_concurrent_downloads
+   )
+   asyncio.run(downloader.download_all(tasks))


-class ResumableDownloader:
-    def __init__(
-        self,
-        url: str,
-        output_file: str,
-        total_size: int = 0,
-        buffer_size: int = 32 * 1024,
-    ):
-        self.url = url
-        self.output_file = output_file
-        self.buffer_size = buffer_size
-        self.total_size = total_size
-        self.downloaded_size = 0
-        self.start_size = 0
-        self.start_time = 0
-
-    async def get_file_info(self, client: httpx.AsyncClient) -> None:
-        if self.total_size > 0:
-            return
-
-        # Force disable compression when trying to retrieve file size
-        response = await client.head(
-            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
-        )
-        response.raise_for_status()
-        self.url = str(response.url)  # Update URL in case of redirects
-        self.total_size = int(response.headers.get("Content-Length", 0))
-        if self.total_size == 0:
-            raise ValueError(
-                "Unable to determine file size. The server might not support range requests."
-            )
-
-    async def download(self) -> None:
-        self.start_time = time.time()
-        async with httpx.AsyncClient(follow_redirects=True) as client:
-            await self.get_file_info(client)
-
-            if os.path.exists(self.output_file):
-                self.downloaded_size = os.path.getsize(self.output_file)
-                self.start_size = self.downloaded_size
-                if self.downloaded_size >= self.total_size:
-                    print(f"Already downloaded `{self.output_file}`, skipping...")
-                    return
-
-            additional_size = self.total_size - self.downloaded_size
-            if not self.has_disk_space(additional_size):
-                M = 1024 * 1024  # noqa
-                print(
-                    f"Not enough disk space to download `{self.output_file}`. "
-                    f"Required: {(additional_size // M):.2f} MB"
-                )
-                raise ValueError(
-                    f"Not enough disk space to download `{self.output_file}`"
-                )
-
-            while True:
-                if self.downloaded_size >= self.total_size:
-                    break
-
-                # Cloudfront has a max-size limit
-                max_chunk_size = 27_000_000_000
-                request_size = min(
-                    self.total_size - self.downloaded_size, max_chunk_size
-                )
-                headers = {
-                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
-                }
-                print(f"Downloading `{self.output_file}`....{headers}")
-                try:
-                    async with client.stream(
-                        "GET", self.url, headers=headers
-                    ) as response:
-                        response.raise_for_status()
-                        with open(self.output_file, "ab") as file:
-                            async for chunk in response.aiter_bytes(self.buffer_size):
-                                file.write(chunk)
-                                self.downloaded_size += len(chunk)
-                                self.print_progress()
-                except httpx.HTTPError as e:
-                    print(f"\nDownload interrupted: {e}")
-                    print("You can resume the download by running the script again.")
-                except Exception as e:
-                    print(f"\nAn error occurred: {e}")
-
-            print(f"\nFinished downloading `{self.output_file}`....")
-
-    def print_progress(self) -> None:
-        percent = (self.downloaded_size / self.total_size) * 100
-        bar_length = 50
-        filled_length = int(bar_length * self.downloaded_size // self.total_size)
-        bar = "█" * filled_length + "-" * (bar_length - filled_length)
-
-        elapsed_time = time.time() - self.start_time
-        M = 1024 * 1024  # noqa
-
-        speed = (
-            (self.downloaded_size - self.start_size) / (elapsed_time * M)
-            if elapsed_time > 0
-            else 0
-        )
-        print(
-            f"\rProgress: |{bar}| {percent:.2f}% "
-            f"({self.downloaded_size // M}/{self.total_size // M} MB) "
-            f"Speed: {speed:.2f} MiB/s",
-            end="",
-            flush=True,
-        )
-
-    def has_disk_space(self, file_size: int) -> bool:
-        dir_path = os.path.dirname(os.path.abspath(self.output_file))
-        free_space = shutil.disk_usage(dir_path).free
-        return free_space > file_size
+def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
+    """Main download command handler"""
+    try:
+        if args.manifest_file:
+            _download_from_manifest(args.manifest_file, args.max_parallel)
+            return
+
+        if args.model_id is None:
+            parser.error("Please provide a model id")
+            return
+
+        # Handle comma-separated model IDs
+        model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
+
+        from llama_models.sku_list import llama_meta_net_info, resolve_model
+
+        from .model.safety_models import (
+            prompt_guard_download_info,
+            prompt_guard_model_sku,
+        )
+
+        prompt_guard = prompt_guard_model_sku()
+        for model_id in model_ids:
+            if model_id == prompt_guard.model_id:
+                model = prompt_guard
+                info = prompt_guard_download_info()
+            else:
+                model = resolve_model(model_id)
+                if model is None:
+                    parser.error(f"Model {model_id} not found")
+                    continue
+                info = llama_meta_net_info(model)
+
+            if args.source == "huggingface":
+                _hf_download(model, args.hf_token, args.ignore_patterns, parser)
+            else:
+                meta_url = args.meta_url or input(
+                    f"Please provide the signed URL for model {model_id} you received via email "
+                    f"after visiting https://www.llama.com/llama-downloads/ "
+                    f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
+                )
+                if "llamameta.net" not in meta_url:
+                    parser.error("Invalid Meta URL provided")
+                _meta_download(model, model_id, meta_url, info, args.max_parallel)
+
+    except Exception as e:
+        parser.error(f"Download failed: {str(e)}")
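DownloadTask and ParallelDownloader are defined elsewhere in this change. As a rough sketch of the idea only (names and behavior assumed here, not the actual implementation), bounded-concurrency downloads can be expressed with an asyncio semaphore:

import asyncio
from dataclasses import dataclass

import httpx


@dataclass
class DownloadTask:
    # Stand-in for the real DownloadTask; field names mirror the call site above.
    url: str
    output_file: str
    max_retries: int = 3


async def _download_one(client: httpx.AsyncClient, task: DownloadTask, sem: asyncio.Semaphore) -> None:
    async with sem:  # at most max_concurrent_downloads run at once
        resp = await client.get(task.url, follow_redirects=True)
        resp.raise_for_status()
        with open(task.output_file, "wb") as f:
            f.write(resp.content)


async def download_all(tasks: list[DownloadTask], max_concurrent_downloads: int = 3) -> None:
    sem = asyncio.Semaphore(max_concurrent_downloads)
    async with httpx.AsyncClient() as client:
        await asyncio.gather(*(_download_one(client, t, sem) for t in tasks))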
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import argparse
|
|||
from .download import Download
|
||||
from .model import ModelParser
|
||||
from .stack import StackParser
|
||||
from .verify_download import VerifyDownload
|
||||
|
||||
|
||||
class LlamaCLIParser:
|
||||
|
|
@@ -27,9 +28,10 @@ class LlamaCLIParser:
|
|||
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
Download.create(subparsers)
|
||||
ModelParser.create(subparsers)
|
||||
StackParser.create(subparsers)
|
||||
Download.create(subparsers)
|
||||
VerifyDownload.create(subparsers)
|
||||
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return self.parser.parse_args()
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
|
|||
from llama_stack.cli.model.download import ModelDownload
|
||||
from llama_stack.cli.model.list import ModelList
|
||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
|
@@ -32,3 +33,4 @@ class ModelParser(Subcommand):
|
|||
ModelList.create(subparsers)
|
||||
ModelPromptFormat.create(subparsers)
|
||||
ModelDescribe.create(subparsers)
|
||||
ModelVerifyDownload.create(subparsers)
|
||||
|
|
|
|||
llama_stack/cli/model/verify_download.py (new file, 24 lines)
|
|
@@ -0,0 +1,24 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelVerifyDownload(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama model verify-download",
|
||||
description="Verify the downloaded checkpoints' checksums",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
from llama_stack.cli.verify_download import setup_verify_download_parser
|
||||
|
||||
setup_verify_download_parser(self.parser)
|
||||
|
|
@@ -9,12 +9,17 @@ import argparse
|
|||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import os
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
TEMPLATES_PATH = (
|
||||
Path(os.path.relpath(__file__)).parent.parent.parent / "distribution" / "templates"
|
||||
)
|
||||
import pkg_resources
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
|
@@ -22,11 +27,10 @@ def available_templates_specs() -> List[BuildConfig]:
|
|||
import yaml
|
||||
|
||||
template_specs = []
|
||||
for p in TEMPLATES_PATH.rglob("*.yaml"):
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
with open(p, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs.append(build_config)
|
||||
|
||||
return template_specs
|
||||
|
||||
|
||||
|
|
@@ -65,174 +69,57 @@ class StackBuild(Subcommand):
|
|||
help="Show the available templates for building a Llama Stack distribution",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
help="Name of the Llama Stack build to override from template config. This name will be used as paths to store configuration files, build conda environments/docker images. If not specified, will use the name from the template config. ",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type to use for the build. This can be either conda or docker. If not specified, will use the image type from the template config.",
|
||||
choices=["conda", "docker"],
|
||||
)
|
||||
|
||||
def _get_build_config_from_name(self, args: argparse.Namespace) -> Optional[Path]:
|
||||
if os.getenv("CONDA_PREFIX", ""):
|
||||
conda_dir = (
|
||||
Path(os.getenv("CONDA_PREFIX")).parent / f"llamastack-{args.name}"
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
"Cannot find CONDA_PREFIX. Trying default conda path ~/.conda/envs...",
|
||||
color="green",
|
||||
)
|
||||
conda_dir = (
|
||||
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.name}"
|
||||
)
|
||||
build_config_file = Path(conda_dir) / f"{args.name}-build.yaml"
|
||||
if build_config_file.exists():
|
||||
return build_config_file
|
||||
|
||||
return None
|
||||
|
||||
def _run_stack_build_command_from_build_config(
|
||||
self, build_config: BuildConfig
|
||||
) -> None:
|
||||
import json
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.build import ApiInput, build_image, ImageType
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
from termcolor import cprint
|
||||
|
||||
# save build.yaml spec for building same distribution again
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
|
||||
llama_stack_path = Path(
|
||||
os.path.abspath(__file__)
|
||||
).parent.parent.parent.parent
|
||||
build_dir = llama_stack_path / "tmp/configs/"
|
||||
else:
|
||||
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
|
||||
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
build_file_path = build_dir / f"{build_config.name}-build.yaml"
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
return_code = build_image(build_config, build_file_path)
|
||||
if return_code != 0:
|
||||
return
|
||||
|
||||
configure_name = (
|
||||
build_config.name
|
||||
if build_config.image_type == "conda"
|
||||
else (f"llamastack-{build_config.name}")
|
||||
)
|
||||
if build_config.image_type == "conda":
|
||||
cprint(
|
||||
f"You can now run `llama stack configure {configure_name}`",
|
||||
color="green",
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"You can now run `llama stack run {build_config.name}`",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"Template Name",
|
||||
"Providers",
|
||||
"Description",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for spec in available_templates_specs():
|
||||
rows.append(
|
||||
[
|
||||
spec.name,
|
||||
json.dumps(spec.distribution_spec.providers, indent=2),
|
||||
spec.distribution_spec.description,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
default="conda",
|
||||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
import textwrap
|
||||
|
||||
import yaml
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
if args.list_templates:
|
||||
self._run_template_list_cmd(args)
|
||||
return
|
||||
|
||||
if args.template:
|
||||
if not args.name:
|
||||
self.parser.error(
|
||||
"You must specify a name for the build using --name when using a template"
|
||||
)
|
||||
return
|
||||
build_path = TEMPLATES_PATH / f"{args.template}-build.yaml"
|
||||
if not build_path.exists():
|
||||
self.parser.error(
|
||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates"
|
||||
)
|
||||
return
|
||||
with open(build_path, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
build_config.name = args.name
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
self._run_stack_build_command_from_build_config(build_config)
|
||||
|
||||
return
|
||||
|
||||
# try to see if we can find a pre-existing build config file through name
|
||||
if args.name:
|
||||
maybe_build_config = self._get_build_config_from_name(args)
|
||||
if maybe_build_config:
|
||||
cprint(
|
||||
f"Building from existing build config for {args.name} in {str(maybe_build_config)}...",
|
||||
"green",
|
||||
)
|
||||
with open(maybe_build_config, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
self._run_stack_build_command_from_build_config(build_config)
|
||||
available_templates = available_templates_specs()
|
||||
for build_config in available_templates:
|
||||
if build_config.name == args.template:
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
else:
|
||||
self.parser.error(
|
||||
f"Please specify a image-type (docker | conda) for {args.template}"
|
||||
)
|
||||
self._run_stack_build_command_from_build_config(
|
||||
build_config, template_name=args.template
|
||||
)
|
||||
return
|
||||
|
||||
self.parser.error(
|
||||
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates"
|
||||
)
|
||||
return
|
||||
|
||||
if not args.config and not args.template:
|
||||
if not args.name:
|
||||
name = prompt(
|
||||
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: len(x) > 0,
|
||||
error_message="Name cannot be empty, please enter a name",
|
||||
),
|
||||
)
|
||||
else:
|
||||
name = args.name
|
||||
name = prompt(
|
||||
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: len(x) > 0,
|
||||
error_message="Name cannot be empty, please enter a name",
|
||||
),
|
||||
)
|
||||
|
||||
image_type = prompt(
|
||||
"> Enter the image type you want your Llama Stack to be built as (docker or conda): ",
|
||||
|
|
@@ -244,26 +131,31 @@ class StackBuild(Subcommand):
|
|||
)
|
||||
|
||||
cprint(
|
||||
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. Let's select
|
||||
the provider types (implementations) you want to use for these APIs.
|
||||
""",
|
||||
),
|
||||
color="green",
|
||||
)
|
||||
|
||||
print("Tip: use <TAB> to see options for the providers.\n")
|
||||
|
||||
providers = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [
|
||||
x
|
||||
for x in providers_for_api.keys()
|
||||
if x not in ("remote", "remote::sample")
|
||||
]
|
||||
api_provider = prompt(
|
||||
"> Enter provider for the {} API: (default=meta-reference): ".format(
|
||||
api.value
|
||||
),
|
||||
"> Enter provider for API {}: ".format(api.value),
|
||||
completer=WordCompleter(available_providers),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in providers_for_api,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
list(providers_for_api.keys())
|
||||
),
|
||||
),
|
||||
default=(
|
||||
"meta-reference"
|
||||
if "meta-reference" in providers_for_api
|
||||
else list(providers_for_api.keys())[0]
|
||||
lambda x: x in available_providers,
|
||||
error_message="Invalid provider, use <TAB> to see options",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@@ -292,3 +184,153 @@ class StackBuild(Subcommand):
|
|||
self.parser.error(f"Could not parse config file {args.config}: {e}")
|
||||
return
|
||||
self._run_stack_build_command_from_build_config(build_config)
|
||||
|
||||
def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None:
|
||||
"""
|
||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||
"""
|
||||
import json
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
|
||||
apis = list(build_config.distribution_spec.providers.keys())
|
||||
run_config = StackRunConfig(
|
||||
docker_image=(
|
||||
build_config.name
|
||||
if build_config.image_type == ImageType.docker.value
|
||||
else None
|
||||
),
|
||||
image_name=build_config.name,
|
||||
conda_env=(
|
||||
build_config.name
|
||||
if build_config.image_type == ImageType.conda.value
|
||||
else None
|
||||
),
|
||||
apis=apis,
|
||||
providers={},
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry()
|
||||
for api in apis:
|
||||
run_config.providers[api] = []
|
||||
provider_types = build_config.distribution_spec.providers[api]
|
||||
if isinstance(provider_types, str):
|
||||
provider_types = [provider_types]
|
||||
|
||||
for i, provider_type in enumerate(provider_types):
|
||||
pid = provider_type.split("::")[-1]
|
||||
|
||||
p = provider_registry[Api(api)][provider_type]
|
||||
if p.deprecation_error:
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
config_type = instantiate_class_type(
|
||||
provider_registry[Api(api)][provider_type].config_class
|
||||
)
|
||||
if hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(
|
||||
__distro_dir__=f"distributions/{build_config.name}"
|
||||
)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
p_spec = Provider(
|
||||
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
)
|
||||
run_config.providers[api].append(p_spec)
|
||||
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
run_config_file = build_dir / f"{build_config.name}-run.yaml"
|
||||
|
||||
with open(run_config_file, "w") as f:
|
||||
to_write = json.loads(run_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
cprint(
|
||||
f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def _run_stack_build_command_from_build_config(
|
||||
self, build_config: BuildConfig, template_name: Optional[str] = None
|
||||
) -> None:
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import build_image
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
# save build.yaml spec for building same distribution again
|
||||
build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}"
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
build_file_path = build_dir / f"{build_config.name}-build.yaml"
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(build_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
return_code = build_image(build_config, build_file_path)
|
||||
if return_code != 0:
|
||||
return
|
||||
|
||||
if template_name:
|
||||
# copy run.yaml from template to build_dir instead of generating it again
|
||||
template_path = pkg_resources.resource_filename(
|
||||
"llama_stack", f"templates/{template_name}/run.yaml"
|
||||
)
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
run_config_file = build_dir / f"{build_config.name}-run.yaml"
|
||||
shutil.copy(template_path, run_config_file)
|
||||
|
||||
with open(template_path, "r") as f:
|
||||
yaml_content = f.read()
|
||||
|
||||
# Find all ${env.VARIABLE} patterns
|
||||
env_vars = set(re.findall(r"\${env\.([A-Za-z0-9_]+)}", yaml_content))
|
||||
cprint("Build Successful! Next steps: ", color="green")
|
||||
cprint(
|
||||
f" 1. Set the environment variables: {list(env_vars)}",
|
||||
color="green",
|
||||
)
|
||||
cprint(
|
||||
f" 2. Run: `llama stack run {template_name}`",
|
||||
color="green",
|
||||
)
|
||||
else:
|
||||
self._generate_run_config(build_config, build_dir)
|
||||
|
||||
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
|
||||
from llama_stack.cli.table import print_table
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"Template Name",
|
||||
"Providers",
|
||||
"Description",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for spec in available_templates_specs():
|
||||
rows.append(
|
||||
[
|
||||
spec.name,
|
||||
json.dumps(spec.distribution_spec.providers, indent=2),
|
||||
spec.distribution_spec.description,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
|
|
|||
|
|
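The template path above only reports which ${env.VARIABLE} placeholders occur in the copied run.yaml; it does not resolve them. A minimal sketch of how such placeholders could be substituted at load time (illustrative only, the actual resolution logic is not part of this hunk):

import os
import re


def substitute_env_vars(yaml_text: str) -> str:
    # Replace ${env.VAR} placeholders with values from the process environment.
    def repl(match: re.Match) -> str:
        name = match.group(1)
        value = os.environ.get(name)
        if value is None:
            raise ValueError(f"Environment variable {name} is not set")
        return value

    return re.sub(r"\$\{env\.([A-Za-z0-9_]+)\}", repl, yaml_text)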
@@ -7,8 +7,6 @@
|
|||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class StackConfigure(Subcommand):
|
||||
|
|
@@ -39,138 +37,10 @@ class StackConfigure(Subcommand):
|
|||
)
|
||||
|
||||
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
docker_image = None
|
||||
|
||||
build_config_file = Path(args.config)
|
||||
|
||||
if build_config_file.exists():
|
||||
with open(build_config_file, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
self._configure_llama_distribution(build_config, args.output_dir)
|
||||
return
|
||||
|
||||
# if we get here, we need to try to find the conda build config file
|
||||
cprint(
|
||||
f"Could not find {build_config_file}. Trying conda build name instead...",
|
||||
color="green",
|
||||
)
|
||||
|
||||
conda_dir = (
|
||||
Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}"
|
||||
)
|
||||
output = subprocess.check_output(
|
||||
["bash", "-c", "conda info --json -a"]
|
||||
)
|
||||
conda_envs = json.loads(output.decode("utf-8"))["envs"]
|
||||
|
||||
for x in conda_envs:
|
||||
if x.endswith(f"/llamastack-{args.config}"):
|
||||
conda_dir = Path(x)
|
||||
break
|
||||
|
||||
build_config_file = Path(conda_dir) / f"{args.config}-build.yaml"
|
||||
|
||||
if build_config_file.exists():
|
||||
with open(build_config_file, "r") as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
|
||||
self._configure_llama_distribution(build_config, args.output_dir)
|
||||
return
|
||||
|
||||
# if we get here, we need to try to find the docker image
|
||||
cprint(
|
||||
f"Could not find {build_config_file}. Trying docker image name instead...",
|
||||
color="green",
|
||||
)
|
||||
docker_image = args.config
|
||||
builds_dir = BUILDS_BASE_DIR / ImageType.docker.value
|
||||
if args.output_dir:
|
||||
builds_dir = Path(output_dir)
|
||||
os.makedirs(builds_dir, exist_ok=True)
|
||||
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/configure_container.sh"
|
||||
)
|
||||
script_args = [script, docker_image, str(builds_dir)]
|
||||
|
||||
return_code = run_with_pty(script_args)
|
||||
|
||||
# we have regenerated the build config file with script, now check if it exists
|
||||
if return_code != 0:
|
||||
self.parser.error(
|
||||
f"Failed to configure container {docker_image} with return code {return_code}. Please run `llama stack build` first. "
|
||||
)
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
def _configure_llama_distribution(
|
||||
self,
|
||||
build_config: BuildConfig,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.configure import configure_api_providers
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
builds_dir = BUILDS_BASE_DIR / build_config.image_type
|
||||
if output_dir:
|
||||
builds_dir = Path(output_dir)
|
||||
os.makedirs(builds_dir, exist_ok=True)
|
||||
image_name = build_config.name.replace("::", "-")
|
||||
run_config_file = builds_dir / f"{image_name}-run.yaml"
|
||||
|
||||
if run_config_file.exists():
|
||||
cprint(
|
||||
f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
|
||||
else:
|
||||
config = StackRunConfig(
|
||||
built_at=datetime.now(),
|
||||
image_name=image_name,
|
||||
apis_to_serve=[],
|
||||
api_providers={},
|
||||
)
|
||||
|
||||
config = configure_api_providers(config, build_config.distribution_spec)
|
||||
|
||||
config.docker_image = (
|
||||
image_name if build_config.image_type == "docker" else None
|
||||
)
|
||||
config.conda_env = image_name if build_config.image_type == "conda" else None
|
||||
|
||||
with open(run_config_file, "w") as f:
|
||||
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
cprint(
|
||||
f"> YAML configuration has been written to `{run_config_file}`.",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
cprint(
|
||||
f"You can now run `llama stack run {image_name} --port PORT`",
|
||||
color="green",
|
||||
self.parser.error(
|
||||
"""
|
||||
DEPRECATED! llama stack configure has been deprecated.
|
||||
Please use llama stack run <path/to/run.yaml> instead.
|
||||
Please see example run.yaml in /distributions folder.
|
||||
"""
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -5,9 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
|
@@ -40,16 +42,24 @@ class StackRun(Subcommand):
|
|||
help="Disable IPv6 support",
|
||||
default=False,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
||||
default=[],
|
||||
metavar="KEY=VALUE",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.build import ImageType
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
BUILDS_BASE_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
)
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
if not args.config:
|
||||
|
|
@@ -57,26 +67,43 @@ class StackRun(Subcommand):
|
|||
return
|
||||
|
||||
config_file = Path(args.config)
|
||||
if not config_file.exists() and not args.config.endswith(".yaml"):
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = (
|
||||
Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
)
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to conda dir
|
||||
config_file = Path(
|
||||
BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
|
||||
)
|
||||
|
||||
if not config_file.exists() and not args.config.endswith(".yaml"):
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to docker dir
|
||||
config_file = Path(
|
||||
BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml"
|
||||
)
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(
|
||||
DISTRIBS_BASE_DIR
|
||||
/ f"llamastack-{args.config}"
|
||||
/ f"{args.config}-run.yaml"
|
||||
)
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist. Please run `llama stack build` and `llama stack configure <name>` to generate a run.yaml file"
|
||||
f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
return
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
config = StackRunConfig(**yaml.safe_load(f))
|
||||
print(f"Using config file: {config_file}")
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
|
||||
if config.docker_image:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
@@ -98,4 +125,16 @@ class StackRun(Subcommand):
|
|||
if args.disable_ipv6:
|
||||
run_args.append("--disable-ipv6")
|
||||
|
||||
for env_var in args.env:
|
||||
if "=" not in env_var:
|
||||
self.parser.error(
|
||||
f"Environment variable '{env_var}' must be in KEY=VALUE format"
|
||||
)
|
||||
return
|
||||
key, value = env_var.split("=", 1) # split on first = only
|
||||
if not key:
|
||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
||||
return
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
run_with_pty(run_args)
|
||||
|
|
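A note on the --env parsing above: splitting on the first "=" only means the value may itself contain "=" characters, which matters for tokens and URLs. A tiny illustration (the variable name is made up):

# maxsplit=1 keeps everything after the first "=" in the value.
key, value = "HF_TOKEN=abc=def".split("=", 1)
assert key == "HF_TOKEN"
assert value == "abc=def"

An invocation such as llama stack run my-run.yaml --env HF_TOKEN=hf_xxx (file name and token illustrative) therefore forwards the pair to the server process unchanged.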
|
|||
|
|
@@ -1,105 +0,0 @@
|
|||
from argparse import Namespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.cli.stack.build import StackBuild
|
||||
|
||||
|
||||
# temporary while we make the tests work
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stack_build():
|
||||
parser = MagicMock()
|
||||
subparsers = MagicMock()
|
||||
return StackBuild(subparsers)
|
||||
|
||||
|
||||
def test_stack_build_initialization(stack_build):
|
||||
assert stack_build.parser is not None
|
||||
assert stack_build.parser.set_defaults.called_once_with(
|
||||
func=stack_build._run_stack_build_command
|
||||
)
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_with_config(
|
||||
mock_build_image, mock_build_config, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config="test_config.yaml",
|
||||
template=None,
|
||||
list_templates=False,
|
||||
name=None,
|
||||
image_type="conda",
|
||||
)
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("yaml.safe_load") as mock_yaml_load:
|
||||
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
||||
|
||||
|
||||
@patch("llama_stack.cli.table.print_table")
|
||||
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
|
||||
args = Namespace(list_templates=True)
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_print_table.assert_called_once()
|
||||
|
||||
|
||||
@patch("prompt_toolkit.prompt")
|
||||
@patch("llama_stack.distribution.datatypes.BuildConfig")
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_interactive(
|
||||
mock_build_image, mock_build_config, mock_prompt, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config=None, template=None, list_templates=False, name=None, image_type=None
|
||||
)
|
||||
|
||||
mock_prompt.side_effect = [
|
||||
"test_name",
|
||||
"conda",
|
||||
"meta-reference",
|
||||
"test description",
|
||||
]
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
assert mock_prompt.call_count == 4
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.datatypes.BuildConfig")
|
||||
@patch("llama_stack.distribution.build.build_image")
|
||||
def test_run_stack_build_command_with_template(
|
||||
mock_build_image, mock_build_config, stack_build
|
||||
):
|
||||
args = Namespace(
|
||||
config=None,
|
||||
template="test_template",
|
||||
list_templates=False,
|
||||
name="test_name",
|
||||
image_type="docker",
|
||||
)
|
||||
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch("yaml.safe_load") as mock_yaml_load:
|
||||
mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
|
||||
mock_build_config.return_value = MagicMock()
|
||||
|
||||
stack_build._run_stack_build_command(args)
|
||||
|
||||
mock_build_config.assert_called_once()
|
||||
mock_build_image.assert_called_once()
|
||||
llama_stack/cli/tests/test_stack_config.py (new file, 133 lines)
|
|
@@ -0,0 +1,133 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from llama_stack.distribution.configure import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
parse_and_maybe_upgrade_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def up_to_date_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
version: {version}
|
||||
image_name: foo
|
||||
apis_to_serve: []
|
||||
built_at: {built_at}
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
safety:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- provider_id: provider1
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
""".format(
|
||||
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
image_name: foo
|
||||
built_at: {built_at}
|
||||
apis_to_serve: []
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
routing_key: Llama3.2-1B-Instruct
|
||||
- provider_type: inline::meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- routing_key: ["shield1", "shield2"]
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
disable_input_check: false
|
||||
disable_output_check: false
|
||||
enable_prompt_guard: false
|
||||
memory:
|
||||
- routing_key: vector
|
||||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
api_providers:
|
||||
telemetry:
|
||||
provider_type: noop
|
||||
config: {{}}
|
||||
""".format(
|
||||
built_at=datetime.now().isoformat()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_config():
|
||||
return yaml.safe_load(
|
||||
"""
|
||||
routing_table: {}
|
||||
api_providers: {}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert "inference" in result.providers
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_old_format(old_config):
|
||||
result = parse_and_maybe_upgrade_config(old_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert all(
|
||||
api in result.providers
|
||||
for api in ["inference", "safety", "memory", "telemetry"]
|
||||
)
|
||||
safety_provider = result.providers["safety"][0]
|
||||
assert safety_provider.provider_type == "meta-reference"
|
||||
assert "llama_guard_shield" in safety_provider.config
|
||||
|
||||
inference_providers = result.providers["inference"]
|
||||
assert len(inference_providers) == 2
|
||||
assert set(x.provider_id for x in inference_providers) == {
|
||||
"remote::ollama-00",
|
||||
"meta-reference-01",
|
||||
}
|
||||
|
||||
ollama = inference_providers[0]
|
||||
assert ollama.provider_type == "remote::ollama"
|
||||
assert ollama.config["port"] == 11434
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
|
||||
with pytest.raises(ValueError):
|
||||
parse_and_maybe_upgrade_config(invalid_config)
|
||||
llama_stack/cli/verify_download.py (new file, 144 lines)
|
|
@@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
filename: str
|
||||
expected_hash: str
|
||||
actual_hash: Optional[str]
|
||||
exists: bool
|
||||
matches: bool
|
||||
|
||||
|
||||
class VerifyDownload(Subcommand):
|
||||
"""Llama cli for verifying downloaded model files"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama verify-download",
|
||||
description="Verify integrity of downloaded model files",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
setup_verify_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
required=True,
|
||||
help="Model ID to verify",
|
||||
)
|
||||
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
||||
|
||||
|
||||
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
|
||||
md5_hash = hashlib.md5()
|
||||
with open(filepath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_size), b""):
|
||||
md5_hash.update(chunk)
|
||||
return md5_hash.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> Dict[str, str]:
|
||||
checksums = {}
|
||||
with open(checklist_path, "r") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
md5sum, filepath = line.strip().split(" ", 1)
|
||||
# Remove leading './' if present
|
||||
filepath = filepath.lstrip("./")
|
||||
checksums[filepath] = md5sum
|
||||
return checksums
|
||||
|
||||
|
||||
def verify_files(
|
||||
model_dir: Path, checksums: Dict[str, str], console: Console
|
||||
) -> List[VerificationResult]:
|
||||
results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
console=console,
|
||||
) as progress:
|
||||
for filepath, expected_hash in checksums.items():
|
||||
full_path = model_dir / filepath
|
||||
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
|
||||
|
||||
exists = full_path.exists()
|
||||
actual_hash = None
|
||||
matches = False
|
||||
|
||||
if exists:
|
||||
actual_hash = calculate_md5(full_path)
|
||||
matches = actual_hash == expected_hash
|
||||
|
||||
results.append(
|
||||
VerificationResult(
|
||||
filename=filepath,
|
||||
expected_hash=expected_hash,
|
||||
actual_hash=actual_hash,
|
||||
exists=exists,
|
||||
matches=matches,
|
||||
)
|
||||
)
|
||||
|
||||
progress.remove_task(task_id)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
console = Console()
|
||||
model_dir = Path(model_local_dir(args.model_id))
|
||||
checklist_path = model_dir / "checklist.chk"
|
||||
|
||||
if not model_dir.exists():
|
||||
parser.error(f"Model directory not found: {model_dir}")
|
||||
|
||||
if not checklist_path.exists():
|
||||
parser.error(f"Checklist file not found: {checklist_path}")
|
||||
|
||||
checksums = load_checksums(checklist_path)
|
||||
results = verify_files(model_dir, checksums, console)
|
||||
|
||||
# Print results
|
||||
console.print("\nVerification Results:")
|
||||
|
||||
all_good = True
|
||||
for result in results:
|
||||
if not result.exists:
|
||||
console.print(f"[red]❌ {result.filename}: File not found[/red]")
|
||||
all_good = False
|
||||
elif not result.matches:
|
||||
console.print(
|
||||
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
|
||||
f" Expected: {result.expected_hash}\n"
|
||||
f" Got: {result.actual_hash}"
|
||||
)
|
||||
all_good = False
|
||||
else:
|
||||
console.print(f"[green]✓ {result.filename}: Verified[/green]")
|
||||
|
||||
if all_good:
|
||||
console.print("\n[green]All files verified successfully![/green]")
|
||||
|
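For orientation, a small usage sketch of the helpers above; the checkpoint directory and digests are made-up placeholders, and the checklist layout is inferred from load_checksums (one "<md5>  <path>" pair per line):

from pathlib import Path

from rich.console import Console

# checklist.chk contents look roughly like:
#   0123456789abcdef0123456789abcdef  ./params.json
#   89abcdef0123456789abcdef01234567  ./consolidated.00.pth
# load_checksums / verify_files are the functions defined in this file.
model_dir = Path("~/.llama/checkpoints/Llama3.2-1B-Instruct").expanduser()
checksums = load_checksums(model_dir / "checklist.chk")
results = verify_files(model_dir, checksums, Console())
print(all(r.exists and r.matches for r in results))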
|
@@ -4,26 +4,29 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
"aiosqlite",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
|
|
@@ -36,28 +39,19 @@ class ImageType(Enum):
|
|||
conda = "conda"
|
||||
|
||||
|
||||
class Dependencies(BaseModel):
|
||||
pip_packages: List[str]
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
|
||||
class ApiInput(BaseModel):
|
||||
api: Api
|
||||
provider: str
|
||||
|
||||
|
||||
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||
package_deps = Dependencies(
|
||||
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
|
||||
pip_packages=SERVER_DEPENDENCIES,
|
||||
)
|
||||
|
||||
# extend package dependencies based on providers spec
|
||||
def get_provider_dependencies(
|
||||
config_providers: Dict[str, List[Provider]]
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
all_providers = get_provider_registry()
|
||||
for (
|
||||
api_str,
|
||||
provider_or_providers,
|
||||
) in build_config.distribution_spec.providers.items():
|
||||
deps = []
|
||||
|
||||
for api_str, provider_or_providers in config_providers.items():
|
||||
providers_for_api = all_providers[Api(api_str)]
|
||||
|
||||
providers = (
|
||||
|
|
@@ -67,25 +61,50 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
)
|
||||
|
||||
for provider in providers:
|
||||
if provider not in providers_for_api:
|
||||
# Providers from BuildConfig and RunConfig are subtly different – not great
|
||||
provider_type = (
|
||||
provider if isinstance(provider, str) else provider.provider_type
|
||||
)
|
||||
|
||||
if provider_type not in providers_for_api:
|
||||
raise ValueError(
|
||||
f"Provider `{provider}` is not available for API `{api_str}`"
|
||||
)
|
||||
|
||||
provider_spec = providers_for_api[provider]
|
||||
package_deps.pip_packages.extend(provider_spec.pip_packages)
|
||||
provider_spec = providers_for_api[provider_type]
|
||||
deps.extend(provider_spec.pip_packages)
|
||||
if provider_spec.docker_image:
|
||||
raise ValueError("A stack's dependencies cannot have a docker image")
|
||||
|
||||
normal_deps = []
|
||||
special_deps = []
|
||||
deps = []
|
||||
for package in package_deps.pip_packages:
|
||||
for package in deps:
|
||||
if "--no-deps" in package or "--index-url" in package:
|
||||
special_deps.append(package)
|
||||
else:
|
||||
deps.append(package)
|
||||
deps = list(set(deps))
|
||||
special_deps = list(set(special_deps))
|
||||
normal_deps.append(package)
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
|
||||
|
||||
def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
||||
normal_deps, special_deps = get_provider_dependencies(providers)
|
||||
|
||||
print(
|
||||
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
|
||||
)
|
||||
for special_dep in special_deps:
|
||||
log.info(f"\tpip install {special_dep}")
|
||||
print()
|
||||
|
||||
|
||||
def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||
docker_image = build_config.distribution_spec.docker_image or "python:3.10-slim"
|
||||
|
||||
normal_deps, special_deps = get_provider_dependencies(
|
||||
build_config.distribution_spec.providers
|
||||
)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
|
||||
if build_config.image_type == ImageType.docker.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
@@ -94,10 +113,10 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
args = [
|
||||
script,
|
||||
build_config.name,
|
||||
package_deps.docker_image,
|
||||
docker_image,
|
||||
str(build_file_path),
|
||||
str(BUILDS_BASE_DIR / ImageType.docker.value),
|
||||
" ".join(deps),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
else:
|
||||
script = pkg_resources.resource_filename(
|
||||
|
|
@@ -107,7 +126,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
script,
|
||||
build_config.name,
|
||||
str(build_file_path),
|
||||
" ".join(deps),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
if special_deps:
|
||||
|
|
@@ -115,9 +134,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
|
||||
return_code = run_with_pty(args)
|
||||
if return_code != 0:
|
||||
cprint(
|
||||
log.error(
|
||||
f"Failed to build target {build_config.name} with return code {return_code}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
|
|
|||
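As a quick illustration of driving the dependency helpers above (the provider types are example values, and plain provider-type strings are accepted because the function normalizes BuildConfig-style input):

# Split a distribution's provider requirements into regular pip packages and
# "special" entries that carry flags such as --no-deps or --index-url.
providers = {
    "inference": ["remote::ollama"],
    "memory": ["inline::meta-reference"],
}
normal_deps, special_deps = get_provider_dependencies(providers)
print(sorted(normal_deps))
print(sorted(special_deps))

# Print copy-pasteable install commands for a user instead:
print_pip_install_help(providers)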
|
|
@@ -1,8 +1,15 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
BUILD_PLATFORM=${BUILD_PLATFORM:-}
|
||||
|
||||
if [ "$#" -lt 4 ]; then
|
||||
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
|
|
@@ -15,7 +22,7 @@ special_pip_deps="$6"
|
|||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
image_name="llamastack-$build_name"
|
||||
image_name="distribution-$build_name"
|
||||
docker_base=$2
|
||||
build_file_path=$3
|
||||
host_build_dir=$4
|
||||
|
|
@@ -30,13 +37,9 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|||
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
|
||||
DOCKER_BINARY=${DOCKER_BINARY:-docker}
|
||||
DOCKER_OPTS=${DOCKER_OPTS:-}
|
||||
REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs"
|
||||
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
|
||||
llama stack configure $build_file_path
|
||||
cp $host_build_dir/$build_name-run.yaml $REPO_CONFIGS_DIR
|
||||
|
||||
add_to_docker() {
|
||||
local input
|
||||
output_file="$TEMP_DIR/Dockerfile"
|
||||
|
|
@@ -62,6 +65,19 @@ RUN apt-get update && apt-get install -y \
|
|||
|
||||
EOF
|
||||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
add_to_docker "RUN pip install --no-cache $pip_dependencies"
|
||||
fi
|
||||
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
add_to_docker "RUN pip install --no-cache $part"
|
||||
done
|
||||
fi
|
||||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
models_mount="/app/llama-models-source"
|
||||
|
||||
|
|
@@ -74,9 +90,18 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
|
|||
# Install in editable format. We will mount the source code into the container
|
||||
# so that changes will be reflected in the container without having to do a
|
||||
# rebuild. This is just for development convenience.
|
||||
add_to_docker "RUN pip install -e $stack_mount"
|
||||
add_to_docker "RUN pip install --no-cache -e $stack_mount"
|
||||
else
|
||||
add_to_docker "RUN pip install llama-stack"
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_docker "RUN pip install fastapi libcst"
|
||||
add_to_docker <<EOF
|
||||
RUN pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
|
||||
EOF
|
||||
else
|
||||
add_to_docker "RUN pip install --no-cache llama-stack"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
|
|
@@ -87,34 +112,20 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|||
|
||||
add_to_docker <<EOF
|
||||
RUN pip uninstall -y llama-models
|
||||
RUN pip install $models_mount
|
||||
RUN pip install --no-cache $models_mount
|
||||
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [ -n "$pip_dependencies" ]; then
|
||||
add_to_docker "RUN pip install $pip_dependencies"
|
||||
fi
|
||||
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<< "$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
add_to_docker "RUN pip install $part"
|
||||
done
|
||||
fi
|
||||
|
||||
add_to_docker <<EOF
|
||||
|
||||
# This would be good in production but for debugging flexibility lets not add it right now
|
||||
# We need a more solid production ready entrypoint.sh anyway
|
||||
#
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$build_name"]
|
||||
|
||||
EOF
|
||||
|
||||
add_to_docker "ADD tmp/configs/$(basename "$build_file_path") ./llamastack-build.yaml"
|
||||
add_to_docker "ADD tmp/configs/$build_name-run.yaml ./llamastack-run.yaml"
|
||||
|
||||
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile"
|
||||
cat $TEMP_DIR/Dockerfile
|
||||
printf "\n"
|
||||
|
|
@@ -127,16 +138,41 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|||
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
|
||||
fi
|
||||
|
||||
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
|
||||
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
|
||||
DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
|
||||
fi
|
||||
|
||||
# Set version tag based on PyPI version
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
version_tag="test-$TEST_PYPI_VERSION"
|
||||
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
|
||||
version_tag="dev"
|
||||
else
|
||||
URL="https://pypi.org/pypi/llama-stack/json"
|
||||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
# Add version tag to image name
|
||||
image_tag="$image_name:$version_tag"
|
||||
|
||||
# Detect platform architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ -n "$BUILD_PLATFORM" ]; then
|
||||
PLATFORM="--platform $BUILD_PLATFORM"
|
||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||
PLATFORM="--platform linux/arm64"
|
||||
elif [ "$ARCH" = "x86_64" ]; then
|
||||
PLATFORM="--platform linux/amd64"
|
||||
else
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -x
|
||||
$DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
$DOCKER_BINARY build $DOCKER_OPTS $PLATFORM -t $image_tag -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -rf $REPO_CONFIGS_DIR
|
||||
set +x
|
||||
|
||||
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
|
||||
echo "Success!"
|
||||
|
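For example (values illustrative), running the build with BUILD_PLATFORM=linux/amd64 and TEST_PYPI_VERSION=0.0.54 set would force an amd64 image and tag it distribution-<name>:test-0.0.54; with neither set and no local source mounts, the tag falls back to the latest llama-stack version reported by PyPI.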
|
|
|||
llama_stack/distribution/client.py (new file, 226 lines)
|
|
@@ -0,0 +1,226 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import Any, get_args, get_origin, Type, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
||||
|
||||
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
||||
client_class = create_api_client_class(protocol)
|
||||
impl = client_class(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
def create_api_client_class(protocol) -> Type:
|
||||
if protocol in _CLIENT_CLASSES:
|
||||
return _CLIENT_CLASSES[protocol]
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url: str):
|
||||
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.routes = {}
|
||||
|
||||
# Store routes for this protocol
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
            assert method_name in self.routes, f"Unknown endpoint: {method_name}"

            # TODO: make this more precise, same thing needs to happen in server.py
            is_streaming = kwargs.get("stream", False)
            if is_streaming:
                return self._call_streaming(method_name, *args, **kwargs)
            else:
                return await self._call_non_streaming(method_name, *args, **kwargs)

        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
            _, sig = self.routes[method_name]

            if sig.return_annotation is None:
                return_type = None
            else:
                return_type = extract_non_async_iterator_type(sig.return_annotation)
                assert (
                    return_type
                ), f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                response = await client.request(**params)
                response.raise_for_status()

                j = response.json()
                if j is None:
                    return None
                # print(f"({protocol.__name__}) Returning {j}, type {return_type}")
                return parse_obj_as(return_type, j)

        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
            webmethod, sig = self.routes[method_name]

            return_type = extract_async_iterator_type(sig.return_annotation)
            assert (
                return_type
            ), f"Could not extract return type for {sig.return_annotation}"

            async with httpx.AsyncClient() as client:
                params = self.httpx_request_params(method_name, *args, **kwargs)
                async with client.stream(**params) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            data = line[len("data: ") :]
                            try:
                                data = json.loads(data)
                                if "error" in data:
                                    cprint(data, "red")
                                    continue

                                yield parse_obj_as(return_type, data)
                            except Exception as e:
                                print(f"Error with parsing or validation: {e}")
                                print(data)

        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
            webmethod, sig = self.routes[method_name]

            parameters = list(sig.parameters.values())[1:]  # skip `self`
            for i, param in enumerate(parameters):
                if i >= len(args):
                    break
                kwargs[param.name] = args[i]

            url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"

            def convert(value):
                if isinstance(value, list):
                    return [convert(v) for v in value]
                elif isinstance(value, dict):
                    return {k: convert(v) for k, v in value.items()}
                elif isinstance(value, BaseModel):
                    return json.loads(value.model_dump_json())
                elif isinstance(value, Enum):
                    return value.value
                else:
                    return value

            params = {}
            data = {}
            if webmethod.method == "GET":
                params.update(kwargs)
            else:
                data.update(convert(kwargs))

            ret = dict(
                method=webmethod.method or "POST",
                url=url,
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                timeout=30,
            )
            if params:
                ret["params"] = params
            if data:
                ret["json"] = data

            return ret
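# Editor's illustration (not part of the diff): a hedged sketch of the request dict
# that httpx_request_params assembles for a POST webmethod. The URL and payload values
# below are made up; only the field names mirror the code above.
example_request = dict(
    method="POST",
    url="http://localhost:5000/alpha/agents/create",  # base_url + LLAMA_STACK_API_VERSION + route (hypothetical values)
    headers={"Accept": "application/json", "Content-Type": "application/json"},
    timeout=30,
    json={"agent_config": {"model": "Llama3.1-8B-Instruct"}},  # BaseModel kwargs converted via model_dump_json()
)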
    # Add protocol methods to the wrapper
    for name, method in inspect.getmembers(protocol):
        if hasattr(method, "__webmethod__"):

            async def method_impl(self, *args, method_name=name, **kwargs):
                return await self.__acall__(method_name, *args, **kwargs)

            method_impl.__name__ = name
            method_impl.__qualname__ = f"APIClient.{name}"
            method_impl.__signature__ = inspect.signature(method)
            setattr(APIClient, name, method_impl)

    # Name the class after the protocol
    APIClient.__name__ = f"{protocol.__name__}Client"
    _CLIENT_CLASSES[protocol] = APIClient
    return APIClient
|
||||
|
||||
|
||||
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if not issubclass(get_origin(arg) or arg, AsyncIterator):
                return arg
    return type_hint


def extract_async_iterator_type(type_hint):
    if get_origin(type_hint) is Union:
        args = get_args(type_hint)
        for arg in args:
            if issubclass(get_origin(arg) or arg, AsyncIterator):
                inner_args = get_args(arg)
                return inner_args[0]
    return None
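# Editor's illustration (not part of the diff): how a Union return annotation is split
# into its non-streaming and streaming halves, mirroring the two helpers above. It uses
# the typing helpers imported at the top of this file; the classes are stand-ins.
class _Turn: ...


class _Chunk: ...


_annotation = Union[_Turn, AsyncIterator[_Chunk]]

_non_streaming = next(
    arg for arg in get_args(_annotation)
    if not issubclass(get_origin(arg) or arg, AsyncIterator)
)
_streaming_item = next(
    get_args(arg)[0] for arg in get_args(_annotation)
    if issubclass(get_origin(arg) or arg, AsyncIterator)
)

assert _non_streaming is _Turn    # what _call_non_streaming parses the JSON body into
assert _streaming_item is _Chunk  # what _call_streaming yields per SSE "data:" line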
|
||||
|
||||
|
||||
async def example(model: str = None):
    from llama_stack.apis.inference import Inference, UserMessage  # noqa: F403
    from llama_stack.apis.inference.event_logger import EventLogger

    client_class = create_api_client_class(Inference)
    client = client_class("http://localhost:5003")

    if not model:
        model = "Llama3.2-3B-Instruct"

    message = UserMessage(
        content="hello world, write me a 2 sentence poem about the moon"
    )
    cprint(f"User>{message.content}", "green")

    stream = True
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
    )

    async for log in EventLogger().log(iterator):
        log.print()


if __name__ == "__main__":
    import asyncio

    asyncio.run(example())
|
||||
|
|
@ -3,189 +3,190 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
import textwrap
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_models.sku_list import (
|
||||
llama3_1_family,
|
||||
llama3_2_family,
|
||||
llama3_family,
|
||||
resolve_model,
|
||||
safety_models,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory.memory import MemoryBankType
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED_MODELS = (
|
||||
llama3_family() + llama3_1_family() + llama3_2_family() + safety_models()
|
||||
)
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
routing_key: str
|
||||
config: config_class
|
||||
def configure_single_provider(
|
||||
registry: Dict[str, ProviderSpec], provider: Provider
|
||||
) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
if provider.config:
|
||||
existing = config_type(**provider.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
|
||||
return BaseModelWithConfig
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
return Provider(
|
||||
provider_id=provider.provider_id,
|
||||
provider_type=provider.provider_type,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
|
||||
"""Get corresponding builtin APIs given provider backed APIs"""
|
||||
res = []
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in provider_backed_apis:
|
||||
res.append(inf.routing_table_api.value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
config: StackRunConfig, build_spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
# append the builtin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
is_nux = len(config.providers) == 0
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
if is_nux:
|
||||
logger.info(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
we need to configure the providers (implementations) you want to use for these APIs.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
provider_registry = get_provider_registry()
|
||||
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = get_provider_registry()
|
||||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
|
||||
# configure simple case for with non-routing providers to api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
if api in builtin_apis:
|
||||
continue
|
||||
if api not in provider_registry:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
logger.info(
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Llama3.1-8B-Instruct",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: resolve_model(x) is not None,
|
||||
error_message="Model must be: {}".format(
|
||||
[x.descriptor() for x in ALLOWED_MODELS]
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
logger.info(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(provider_registry[api], p)
|
||||
)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=[s.value for s in MetaReferenceShieldType],
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
|
||||
"yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
||||
providers=p if isinstance(p, list) else [p]
|
||||
)
|
||||
logger.info("")
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_type=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
plist = plist if isinstance(plist, list) else [plist]
|
||||
|
||||
print("")
|
||||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(plist[i:])
|
||||
logger.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider_type})`")
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
provider_registry[api],
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{provider_type}-{i:02d}"
|
||||
if len(plist) > 1
|
||||
else provider_type
|
||||
),
|
||||
provider_type=provider_type,
|
||||
config={},
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def upgrade_from_routing_table(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
provider_id=(
|
||||
f"{entry['provider_type']}-{i:02d}"
|
||||
if len(entries) > 1
|
||||
else entry["provider_type"]
|
||||
),
|
||||
provider_type=entry["provider_type"],
|
||||
config=entry["config"],
|
||||
)
|
||||
for i, entry in enumerate(entries)
|
||||
]
|
||||
|
||||
providers_by_api = {}
|
||||
|
||||
routing_table = config_dict.get("routing_table", {})
|
||||
for api_str, entries in routing_table.items():
|
||||
providers = get_providers(entries)
|
||||
providers_by_api[api_str] = providers
|
||||
|
||||
provider_map = config_dict.get("api_providers", config_dict.get("provider_map", {}))
|
||||
if provider_map:
|
||||
for api_str, provider in provider_map.items():
|
||||
if isinstance(provider, dict) and "provider_type" in provider:
|
||||
providers_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=f"{provider['provider_type']}",
|
||||
provider_type=provider["provider_type"],
|
||||
config=provider["config"],
|
||||
)
|
||||
]
|
||||
|
||||
config_dict["providers"] = providers_by_api
|
||||
|
||||
config_dict.pop("routing_table", None)
|
||||
config_dict.pop("api_providers", None)
|
||||
config_dict.pop("provider_map", None)
|
||||
|
||||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
return StackRunConfig(**config_dict)
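# Editor's illustration (not part of the diff): the shape change performed by
# upgrade_from_routing_table. Values are made up; with a single routing_table entry per
# API, the resulting Provider's provider_id defaults to its provider_type.
old_style = {
    "apis_to_serve": ["inference"],
    "routing_table": {
        "inference": [
            {"provider_type": "meta-reference", "config": {"model": "Llama3.1-8B-Instruct"}},
        ],
    },
}
# upgrade_from_routing_table(old_style) yields, roughly:
new_style = {
    "apis": ["inference"],
    "providers": {
        "inference": [
            # Provider(provider_id="meta-reference", provider_type="meta-reference", config={...})
        ],
    },
}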
|
||||
|
|
|
|||
|
|
@ -4,35 +4,62 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTaskInput
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
MemoryBank,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
Model,
|
||||
Shield,
|
||||
MemoryBank,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
EvalTask,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
Memory,
|
||||
DatasetIO,
|
||||
Scoring,
|
||||
Eval,
|
||||
]
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
|
|
@ -53,18 +80,16 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
router_api: Api
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
default="",
|
||||
|
|
@ -80,10 +105,14 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
built_at: datetime
|
||||
|
||||
image_name: str = Field(
|
||||
...,
|
||||
|
|
@ -100,36 +129,34 @@ this could be just a hash
|
|||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
apis_to_serve: List[str] = Field(
|
||||
apis: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
api_providers: Dict[
|
||||
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
|
||||
] = Field(
|
||||
providers: Dict[str, List[Provider]] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||
default_factory=dict,
|
||||
metadata_store: Optional[KVStoreConfig] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Llama3.1-8B-Instruct
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
""",
|
||||
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: List[ModelInput] = Field(default_factory=list)
|
||||
shields: List[ShieldInput] = Field(default_factory=list)
|
||||
memory_banks: List[MemoryBankInput] = Field(default_factory=list)
|
||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
|
|
@ -35,6 +35,18 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.memory_banks,
|
||||
router_api=Api.memory,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.datasets,
|
||||
router_api=Api.datasetio,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.scoring_functions,
|
||||
router_api=Api.scoring,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.eval_tasks,
|
||||
router_api=Api.eval,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -50,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_type: a for a in module.available_providers()},
|
||||
}
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
|
||||
return ret
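# Editor's illustration (not part of the diff): after this change the registry maps each
# API straight to its providers keyed by provider_type; the implicit "remote" entry is no
# longer injected. The provider_type string below is only an example.
registry = get_provider_registry()
inference_providers = registry[Api.inference]       # Dict[str, ProviderSpec]
spec = inference_providers.get("meta-reference")    # None if the provider type is unknown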
|
||||
|
|
|
|||
|
|
@ -6,45 +6,58 @@
|
|||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
for api, providers in run_config.providers.items():
|
||||
ret[api] = [
|
||||
ProviderInfo(
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
for p in providers
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
run_config = self.config.run_config
|
||||
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
providers = run_config.providers.get(api.value, [])
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_THREAD_LOCAL = threading.local()
|
||||
|
||||
|
||||
|
|
@ -32,7 +35,7 @@ class NeedsRequestProviderData:
|
|||
provider_data = validator(**val)
|
||||
return provider_data
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
log.error("Error parsing provider data", e)
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str]):
|
||||
|
|
@ -51,7 +54,7 @@ def set_request_provider_data(headers: Dict[str, str]):
|
|||
try:
|
||||
val = json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
print("Provider data not encoded as a JSON object!", val)
|
||||
log.error("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
|
||||
_THREAD_LOCAL.provider_data_header_value = val
|
||||
|
|
|
|||
|
|
@ -4,159 +4,287 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||
|
||||
import logging
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.eval_tasks import EvalTasks
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.memory: Memory,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.datasetio: DatasetIO,
|
||||
Api.datasets: Datasets,
|
||||
Api.scoring: Scoring,
|
||||
Api.scoring_functions: ScoringFunctions,
|
||||
Api.eval: Eval,
|
||||
Api.eval_tasks: EvalTasks,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
ScoringFunctions,
|
||||
Api.scoring_functions,
|
||||
),
|
||||
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
||||
}
|
||||
|
||||
|
||||
# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF!
|
||||
class ProviderWithSpec(Provider):
|
||||
spec: ProviderSpec
|
||||
|
||||
|
||||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = get_provider_registry()
|
||||
specs = {}
|
||||
configs = {}
|
||||
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# skip checks for API whose provider config is specified in routing_table
|
||||
if isinstance(config, PlaceholderProviderConfig):
|
||||
continue
|
||||
|
||||
if config.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{config.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[config.provider_type]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
|
||||
providers_with_specs = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
log.error(p.deprecation_error, "red", attrs=["bold"])
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
|
||||
elif p.deprecation_warning:
|
||||
log.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.model_dump()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
key = api_str if api not in router_apis else f"inner-{api_str}"
|
||||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
providers_with_specs[info.routing_table_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__routing_table__",
|
||||
provider_type="__routing_table__",
|
||||
config={},
|
||||
spec=RoutingTableProviderSpec(
|
||||
api=info.routing_table_api,
|
||||
router_api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
deps__=([f"inner-{info.router_api.value}"]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
routing_table = run_config.routing_table[info.router_api.value]
|
||||
providers_with_specs[info.router_api.value] = {
|
||||
"__builtin__": ProviderWithSpec(
|
||||
provider_id="__autorouted__",
|
||||
provider_type="__autorouted__",
|
||||
config={},
|
||||
spec=AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
inner_deps = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_type not in providers:
|
||||
raise ValueError(
|
||||
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_type])
|
||||
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
inner_specs=inner_specs,
|
||||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={
|
||||
"run_config": run_config.dict(),
|
||||
},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
),
|
||||
),
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_type}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
impls[Api.inspect] = DistributionInspectImpl()
|
||||
specs[Api.inspect] = InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__distribution_builtin__",
|
||||
config_class="",
|
||||
module="",
|
||||
)
|
||||
|
||||
return impls, specs
|
||||
log.info(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
log.info(f" {api_str} => {provider.provider_id}")
|
||||
log.info("")
|
||||
|
||||
impls = {}
|
||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
dist_registry,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||
by_id = {x.api: x for x in providers}
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[ProviderWithSpec]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
||||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
deps = []
|
||||
for provider in providers:
|
||||
for dep in provider.spec.deps__:
|
||||
deps.append(dep)
|
||||
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for dep in deps:
|
||||
if dep not in visited:
|
||||
dfs((dep, providers_with_specs[dep]), visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
stack.append(api_str)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in providers:
|
||||
if a.api not in visited:
|
||||
dfs(a, visited, stack)
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
dfs((api_str, providers), visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
flattened = []
|
||||
for api_str in stack:
|
||||
for provider in providers_with_specs[api_str]:
|
||||
flattened.append((api_str, provider))
|
||||
return flattened
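# Editor's illustration (not part of the diff): a toy version of the dependency ordering
# performed above. The key names mimic the "inner-" convention but are made up.
_deps = {
    "inner-inference": [],
    "models": ["inner-inference"],   # routing table depends on the inner providers it routes to
    "inference": ["models"],         # auto-routed API depends on its routing table
    "agents": ["inference"],
}

_order: list = []
_visited: set = set()


def _dfs(node):
    _visited.add(node)
    for dep in _deps[node]:
        if dep not in _visited:
            _dfs(dep)
    _order.append(node)


for _node in _deps:
    if _node not in _visited:
        _dfs(_node)

assert _order.index("models") < _order.index("inference") < _order.index("agents")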
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
inner_impls: Dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
|
||||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
|
|
@ -165,31 +293,95 @@ async def instantiate_provider(
|
|||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_id__ = provider.provider_id
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||
missing_methods = []
|
||||
|
||||
mro = type(obj).__mro__
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
if not hasattr(obj, name):
|
||||
missing_methods.append((name, "missing"))
|
||||
elif not callable(getattr(obj, name)):
|
||||
missing_methods.append((name, "not_callable"))
|
||||
else:
|
||||
# Check if the method signatures are compatible
|
||||
obj_method = getattr(obj, name)
|
||||
proto_sig = inspect.signature(value)
|
||||
obj_sig = inspect.signature(obj_method)
|
||||
|
||||
proto_params = set(proto_sig.parameters)
|
||||
proto_params.discard("self")
|
||||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
log.error(
|
||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||
)
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
)
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
):
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
raise ValueError(
|
||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||
)
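# Editor's illustration (not part of the diff): the kind of mismatch the compliance check
# above flags. The _webmethod stand-in below is hypothetical; the real decorator also
# carries route information.
import inspect as _inspect
from typing import Optional, Protocol


def _webmethod(fn):
    fn.__webmethod__ = True  # marker attribute the compliance check looks for
    return fn


class _ExampleProtocol(Protocol):
    @_webmethod
    async def list_models(self, filter: Optional[str] = None): ...


class _GoodImpl:
    async def list_models(self, filter: Optional[str] = None):
        return []


class _BadImpl:
    async def list_models(self):  # missing `filter` -> flagged as signature_mismatch
        return []


_proto = set(_inspect.signature(_ExampleProtocol.list_models).parameters) - {"self"}
assert _proto <= set(_inspect.signature(_GoodImpl.list_models).parameters)
assert not _proto <= set(_inspect.signature(_BadImpl.list_models).parameters)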
|
||||
|
||||
|
||||
async def resolve_remote_stack_impls(
|
||||
config: RemoteProviderConfig,
|
||||
apis: List[str],
|
||||
) -> Dict[Api, Any]:
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
impls = {}
|
||||
for api_str in apis:
|
||||
api = Api(api_str)
|
||||
impls[api] = await get_client_impl(
|
||||
protocols[api],
|
||||
config,
|
||||
{},
|
||||
)
|
||||
if api in additional_protocols:
|
||||
_, additional_protocol, additional_api = additional_protocols[api]
|
||||
impls[additional_api] = await get_client_impl(
|
||||
additional_protocol,
|
||||
config,
|
||||
{},
|
||||
)
|
||||
|
||||
return impls
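async def _demo_remote_impls(config: RemoteProviderConfig):
    # Editor's illustration (not part of the diff): resolve client-side implementations
    # that speak HTTP to a remote stack for a subset of APIs.
    impls = await resolve_remote_stack_impls(config, ["inference"])
    return impls[Api.inference]  # an auto-generated APIClient bound to the Inference protocol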
|
||||
|
|
|
|||
|
|
@ -4,43 +4,62 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
|
||||
from .routing_tables import (
|
||||
DatasetsRoutingTable,
|
||||
EvalTasksRoutingTable,
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ScoringFunctionsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"eval_tasks": EvalTasksRoutingTable,
|
||||
}
|
||||
|
||||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
InferenceRouter,
|
||||
MemoryRouter,
|
||||
SafetyRouter,
|
||||
ScoringRouter,
|
||||
)
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
"datasetio": DatasetIORouter,
|
||||
"scoring": ScoringRouter,
|
||||
"eval": EvalRouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
|
|
|||
|
|
@ -4,24 +4,27 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.datasetio.datasetio import DatasetIO
|
||||
from llama_stack.apis.memory_banks.memory_banks import BankParams
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
"""Routes to an provider based on the memory bank identifier"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
|
@ -29,32 +32,19 @@ class MemoryRouter(Memory):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
|
||||
name, config, url
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memorybank_id: Optional[str] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_memory_bank(
|
||||
memory_bank_id,
|
||||
params,
|
||||
provider_id,
|
||||
provider_memorybank_id,
|
||||
)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
|
@ -62,7 +52,7 @@ class MemoryRouter(Memory):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.get_provider_from_bank_id(bank_id).insert_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
)
|
||||
|
||||
|
|
@ -72,7 +62,7 @@ class MemoryRouter(Memory):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.get_provider_from_bank_id(bank_id).query_documents(
|
||||
return await self.routing_table.get_provider_impl(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
)
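async def _demo_memory_routing(memory_router, params):
    # Editor's illustration (not part of the diff): banks are now addressed purely by
    # identifier, and the routing table resolves the provider from that id. `params`
    # stands in for a BankParams instance (e.g. vector-bank parameters); values are made up.
    await memory_router.register_memory_bank(memory_bank_id="docs", params=params)
    return await memory_router.query_documents(bank_id="docs", query="what is llama stack?")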
|
||||
|
||||
|
|
@ -92,11 +82,23 @@ class InferenceRouter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
|
|
@ -104,44 +106,52 @@ class InferenceRouter(Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
params = dict(
|
||||
model=model,
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||
else:
|
||||
return await provider.chat_completion(**params)
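async def _demo_inference_router(router, messages):
    # Editor's illustration (not part of the diff): after this change the router hands back
    # either a single response or an async generator, depending on `stream`. The model id
    # is only an example.
    response = await router.chat_completion(
        model_id="Llama3.1-8B-Instruct", messages=messages, stream=False
    )

    async for chunk in await router.chat_completion(
        model_id="Llama3.1-8B-Instruct", messages=messages, stream=True
    ):
        ...  # each chunk is a streamed completion event from the routed provider
    return response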
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
model=model,
|
||||
) -> AsyncGenerator:
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.completion(**params))
|
||||
else:
|
||||
return await provider.completion(**params)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
return await self.routing_table.get_provider_impl(model).embeddings(
|
||||
model=model,
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
|
@ -159,14 +169,178 @@ class SafetyRouter(Safety):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
return await self.routing_table.register_shield(
|
||||
shield_id, provider_shield_id, provider_id, params
|
||||
)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
dataset_id
|
||||
).get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
page_token=page_token,
|
||||
filter_condition=filter_condition,
|
||||
)
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
res.update(score_response.results)
|
||||
|
||||
if save_results_dataset:
|
||||
raise NotImplementedError("Save results dataset not implemented yet")
|
||||
|
||||
return ScoreBatchResponse(
|
||||
results=res,
|
||||
)
|
||||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
res.update(score_response.results)
|
||||
|
||||
return ScoreResponse(results=res)
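As a rough illustration of the per-function fan-out above, a standalone sketch (the provider objects and scoring-function ids are made up for the example, not part of this change):

import asyncio

class FakeScoringProvider:
    # stands in for a real scoring provider impl resolved via the routing table
    def __init__(self, name: str):
        self.name = name

    async def score(self, input_rows, scoring_functions):
        # each provider scores only the functions routed to it
        return {fn: f"scored-by-{self.name}" for fn in scoring_functions}

async def demo():
    routing = {
        "basic::equality": FakeScoringProvider("basic"),
        "judge::accuracy": FakeScoringProvider("judge"),
    }
    merged = {}
    for fn_id, provider in routing.items():
        partial = await provider.score(
            input_rows=[{"answer": "42"}],
            scoring_functions={fn_id: None},
        )
        merged.update(partial)  # same merge the router does with score_response.results
    print(merged)

asyncio.run(demo())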
|
||||
|
||||
|
||||
class EvalRouter(Eval):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: AppEvalTaskConfig,
|
||||
) -> Job:
|
||||
return await self.routing_table.get_provider_impl(task_id).run_eval(
|
||||
task_id=task_id,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
@webmethod(route="/eval/evaluate_rows", method="POST")
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: EvalTaskConfig,
|
||||
) -> EvaluateResponse:
|
||||
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
|
||||
task_id=task_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
async def job_status(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
return await self.routing_table.get_provider_impl(task_id).job_status(
|
||||
task_id, job_id
|
||||
)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
await self.routing_table.get_provider_impl(task_id).job_cancel(
|
||||
task_id,
|
||||
job_id,
|
||||
)
|
||||
|
||||
async def job_result(
|
||||
self,
|
||||
task_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
return await self.routing_table.get_provider_impl(task_id).job_result(
|
||||
task_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,141 +4,427 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
# TODO: this should return the registered object for all APIs
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
|
||||
|
||||
api = get_impl_api(p)
|
||||
|
||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
return await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
return await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
return await p.register_eval_task(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.unregister_memory_bank(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[RoutingKey, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> None:
|
||||
self.unique_providers = []
|
||||
self.providers = {}
|
||||
self.routing_keys = []
|
||||
|
||||
for key, impl in inner_impls:
|
||||
keys = key if isinstance(key, list) else [key]
|
||||
self.unique_providers.append((keys, impl))
|
||||
|
||||
for k in keys:
|
||||
if k in self.providers:
|
||||
raise ValueError(f"Duplicate routing key {k}")
|
||||
self.providers[k] = impl
|
||||
self.routing_keys.append(k)
|
||||
|
||||
self.routing_table_config = routing_table_config
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for keys, p in self.unique_providers:
|
||||
spec = p.__provider_spec__
|
||||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||
continue
|
||||
|
||||
await p.validate_routing_keys(keys)
|
||||
async def add_objects(
|
||||
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||
) -> None:
|
||||
for obj in objs:
|
||||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Register all objects from providers
|
||||
for pid, p in self.impls_by_provider_id.items():
|
||||
api = get_impl_api(p)
|
||||
if api == Api.inference:
|
||||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for _, p in self.unique_providers:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
def get_provider_impl(
|
||||
self, routing_key: str, provider_id: Optional[str] = None
|
||||
) -> Any:
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, MemoryBanksRoutingTable):
|
||||
return ("Memory", "memory_bank")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
return ("Scoring", "scoring_function")
|
||||
elif isinstance(self, EvalTasksRoutingTable):
|
||||
return ("Eval", "eval_task")
|
||||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
apiname, objtype = apiname_object()
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
# Get objects from disk registry
|
||||
obj = self.dist_registry.get_cached(objtype, routing_key)
|
||||
if not obj:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if len(provider_ids) > 1:
|
||||
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
|
||||
else:
|
||||
provider_ids_str = f"provider: `{provider_ids[0]}`"
|
||||
raise ValueError(
|
||||
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
|
||||
)
|
||||
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
# Get from disk registry
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
# Get existing objects from registry
|
||||
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[Model]:
|
||||
return await self.get_all_with_type("model")
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_model(self, identifier: str) -> Optional[Model]:
|
||||
return await self.get_object_by_identifier("model", identifier)
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
)
|
||||
return None
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
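The defaulting rules in register_model() can be summarized with a small hedged sketch (the identifiers below are examples, not defaults shipped with the stack):

def resolve_model_registration(model_id, provider_model_id=None, provider_id=None,
                               available_providers=("ollama",)):
    provider_model_id = provider_model_id or model_id            # fall back to the public id
    if provider_id is None:
        if len(available_providers) == 1:
            provider_id = available_providers[0]                  # single provider: use it
        else:
            raise ValueError("Multiple providers available; please specify provider_id")
    return {
        "identifier": model_id,
        "provider_resource_id": provider_model_id,
        "provider_id": provider_id,
    }

print(resolve_model_registration("meta-llama/Llama-3.1-8B-Instruct"))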
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[Shield]:
|
||||
return await self.get_all_with_type(ResourceType.shield.value)
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
if isinstance(entry.routing_key, list):
|
||||
for k in entry.routing_key:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=k,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
if provider_shield_id is None:
|
||||
provider_shield_id = shield_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
params=params,
|
||||
)
|
||||
await self.register_object(shield)
|
||||
return shield
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return await self.get_all_with_type(ResourceType.memory_bank.value)
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
|
||||
|
||||
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank:
|
||||
if provider_memory_bank_id is None:
|
||||
provider_memory_bank_id = memory_bank_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this memory bank
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
return None
|
||||
memory_bank = parse_obj_as(
|
||||
MemoryBank,
|
||||
{
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
},
|
||||
)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.unregister_object(existing_bank)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[Dataset]:
|
||||
return await self.get_all_with_type(ResourceType.dataset.value)
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_schema: Dict[str, ParamType],
|
||||
url: URL,
|
||||
provider_dataset_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if provider_dataset_id is None:
|
||||
provider_dataset_id = dataset_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this dataset
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
dataset = Dataset(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
dataset_schema=dataset_schema,
|
||||
url=url,
|
||||
metadata=metadata,
|
||||
)
|
||||
await self.register_object(dataset)
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
) -> None:
|
||||
if provider_scoring_fn_id is None:
|
||||
provider_scoring_fn_id = scoring_fn_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFn(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
provider_resource_id=provider_scoring_fn_id,
|
||||
provider_id=provider_id,
|
||||
params=params,
|
||||
)
|
||||
scoring_fn.provider_id = provider_id
|
||||
await self.register_object(scoring_fn)
|
||||
|
||||
|
||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||
async def list_eval_tasks(self) -> List[EvalTask]:
|
||||
return await self.get_all_with_type(ResourceType.eval_task.value)
|
||||
|
||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||
return await self.get_object_by_identifier("eval_task", name)
|
||||
|
||||
async def register_eval_task(
|
||||
self,
|
||||
eval_task_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider_eval_task_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
if provider_eval_task_id is None:
|
||||
provider_eval_task_id = eval_task_id
|
||||
eval_task = EvalTask(
|
||||
identifier=eval_task_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
metadata=metadata,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=provider_eval_task_id,
|
||||
)
|
||||
await self.register_object(eval_task)
|
||||
|
|
|
|||
|
|
@ -9,15 +9,9 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -31,18 +25,7 @@ class ApiEndpoint(BaseModel):
|
|||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = {
|
||||
Api.inference: Inference,
|
||||
Api.safety: Safety,
|
||||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.inspect: Inspect,
|
||||
}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
endpoints = []
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
|
@ -52,7 +35,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
continue
|
||||
|
||||
webmethod = method.__webmethod__
|
||||
route = webmethod.route
|
||||
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
|
||||
if webmethod.method == "GET":
|
||||
method = "get"
|
||||
|
|
|
|||
|
|
@ -4,62 +4,69 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.providers.inline.meta_reference.telemetry.console import (
|
||||
ConsoleConfig,
|
||||
ConsoleTelemetryImpl,
|
||||
)
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
log = file if hasattr(file, "write") else sys.stderr
|
||||
traceback.print_stack(file=log)
|
||||
log.write(warnings.formatwarning(message, category, filename, lineno, line))
|
||||
|
||||
|
||||
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
|
||||
warnings.showwarning = warn_with_traceback
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
data = data.model_dump_json()
|
||||
else:
|
||||
data = json.dumps(data)
|
||||
|
||||
|
|
@ -108,72 +115,20 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
params=request.query_params,
|
||||
)
|
||||
response = await client.send(req, stream=True)
|
||||
|
||||
async def stream_response():
|
||||
async for chunk in response.aiter_raw(chunk_size=64):
|
||||
yield chunk
|
||||
|
||||
await response.aclose()
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
async def run_shutdown():
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
print(f"Shutting down {impl}")
|
||||
await impl.shutdown()
|
||||
|
||||
asyncio.run(run_shutdown())
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
|
||||
loop.stop()
|
||||
|
||||
|
||||
|
|
@ -182,76 +137,57 @@ async def lifespan(app: FastAPI):
|
|||
print("Starting up")
|
||||
yield
|
||||
print("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
||||
return endpoint
|
||||
|
||||
async def maybe_await(value):
|
||||
if inspect.iscoroutine(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
|
@ -275,54 +211,118 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
return endpoint
|
||||
|
||||
|
||||
def main(
|
||||
yaml_config: str = "llamastack-run.yaml",
|
||||
port: int = 5000,
|
||||
disable_ipv6: bool = False,
|
||||
):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
app = FastAPI()
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
await start_trace(path, {"location": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def main():
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
"--yaml-config",
|
||||
help="Path to YAML configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template",
|
||||
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
|
||||
parser.add_argument(
|
||||
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
print(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
print(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.yaml_config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
print(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = (
|
||||
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
print(f"Using template {args.template} config file: {config_file}")
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
with open(config_file, "r") as fp:
|
||||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
print("Run configuration:")
|
||||
print(yaml.dump(config.model_dump(), indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError:
|
||||
sys.exit(1)
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
apis_to_serve.add(Api.inspect)
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
# if we do not serve the corresponding router API, we should not serve the routing table API
|
||||
if inf.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, module="pydantic._internal._fields"
|
||||
)
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(
|
||||
f"Could not find method {endpoint.name} on {impl}!!"
|
||||
)
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
|
|
@ -337,15 +337,18 @@ def main(
|
|||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
|
||||
import uvicorn
|
||||
|
||||
# FYI this does not do hot-reloads
|
||||
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
|
||||
print(f"Listening on {listen_host}:{port}")
|
||||
uvicorn.run(app, host=listen_host, port=port)
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
print(f"Listening on {listen_host}:{args.port}")
|
||||
uvicorn.run(app, host=listen_host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
main()
|
||||
|
|
|
|||
203
llama_stack/distribution/stack.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.batch_inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from llama_stack.apis.post_training import * # noqa: F403
|
||||
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LLAMA_STACK_API_VERSION = "alpha"
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
MemoryBanks,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
Datasets,
|
||||
Telemetry,
|
||||
PostTraining,
|
||||
Memory,
|
||||
Eval,
|
||||
EvalTasks,
|
||||
Scoring,
|
||||
ScoringFunctions,
|
||||
DatasetIO,
|
||||
Models,
|
||||
Shields,
|
||||
Inspect,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
RESOURCES = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
Api.scoring_functions,
|
||||
"register_scoring_function",
|
||||
"list_scoring_functions",
|
||||
),
|
||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||
]
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
continue
|
||||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
await method(**obj.model_dump())
|
||||
|
||||
method = getattr(impls[api], list_method)
|
||||
for obj in await method():
|
||||
log.info(
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
)
|
||||
|
||||
log.info("")
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
self.var_name = var_name
|
||||
self.path = path
|
||||
super().__init__(
|
||||
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
|
||||
)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
for k, v in config.items():
|
||||
try:
|
||||
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
||||
elif isinstance(config, list):
|
||||
result = []
|
||||
for i, v in enumerate(config):
|
||||
try:
|
||||
result.append(replace_env_vars(v, f"{path}[{i}]"))
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
return result
|
||||
|
||||
elif isinstance(config, str):
|
||||
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
|
||||
|
||||
def get_env_var(match):
|
||||
env_var = match.group(1)
|
||||
default_val = match.group(2)
|
||||
|
||||
value = os.environ.get(env_var)
|
||||
if not value:
|
||||
if default_val is None:
|
||||
raise EnvVarError(env_var, path)
|
||||
else:
|
||||
value = default_val
|
||||
|
||||
# expand "~" from the values
|
||||
return os.path.expanduser(value)
|
||||
|
||||
try:
|
||||
return re.sub(pattern, get_env_var, config)
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
||||
return config
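A quick example of the ${env.VAR:default} substitution implemented above (the values are illustrative; assumes replace_env_vars from this module is in scope):

import os
os.environ["INFERENCE_PORT"] = "8000"
cfg = {
    "url": "http://localhost:${env.INFERENCE_PORT}",
    "api_key": "${env.MY_API_KEY:}",  # empty default -> empty string when unset
    "model": "${env.INFERENCE_MODEL:meta-llama/Llama-3.1-8B-Instruct}",
}
print(replace_env_vars(cfg))
# {'url': 'http://localhost:8000', 'api_key': '', 'model': 'meta-llama/Llama-3.1-8B-Instruct'}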
|
||||
|
||||
|
||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||
"""Validate and split an environment variable key-value pair."""
|
||||
try:
|
||||
key, value = env_pair.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
raise ValueError(
|
||||
f"Key must contain only alphanumeric characters and underscores: {key}"
|
||||
)
|
||||
return key, value
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
||||
) from e
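For example (assuming validate_env_pair above is in scope; the keys and values are made up):

print(validate_env_pair("INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct"))
# ('INFERENCE_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
print(validate_env_pair("SEARCH_API_KEY=abc=123"))  # the value may itself contain '='
# ('SEARCH_API_KEY', 'abc=123')
# validate_env_pair("not-a-pair") raises ValueError with the "Expected format: KEY=value" message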
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(
|
||||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(), dist_registry
|
||||
)
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||
template_path = pkg_resources.resource_filename(
|
||||
"llama_stack", f"templates/{template}/run.yaml"
|
||||
)
|
||||
|
||||
if not Path(template_path).exists():
|
||||
raise ValueError(f"Template '{template}' not found at {template_path}")
|
||||
|
||||
with open(template_path) as f:
|
||||
run_config = yaml.safe_load(f)
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
|
|
@ -33,10 +33,33 @@ shift
|
|||
port="$1"
|
||||
shift
|
||||
|
||||
# Process environment variables from --env arguments
|
||||
env_vars=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--env)
|
||||
|
||||
if [[ -n "$2" ]]; then
|
||||
# collect environment variables so we can set them after activating the conda env
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "$env_name"
|
||||
|
||||
set -x
|
||||
$CONDA_PREFIX/bin/python \
|
||||
-m llama_stack.distribution.server.server \
|
||||
--yaml_config "$yaml_config" \
|
||||
--port "$port" "$@"
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ DOCKER_BINARY=${DOCKER_BINARY:-docker}
|
|||
DOCKER_OPTS=${DOCKER_OPTS:-}
|
||||
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
|
@ -29,7 +31,7 @@ if [ $# -lt 3 ]; then
|
|||
fi
|
||||
|
||||
build_name="$1"
|
||||
docker_image="llamastack-$build_name"
|
||||
docker_image="localhost/distribution-$build_name"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
|
|
@ -38,6 +40,26 @@ shift
|
|||
port="$1"
|
||||
shift
|
||||
|
||||
# Process environment variables from --env arguments
|
||||
env_vars=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--env)
|
||||
echo "env = $2"
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars -e $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
set -x
|
||||
|
||||
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
|
|
@ -54,11 +76,21 @@ if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
|
|||
DOCKER_OPTS="$DOCKER_OPTS --gpus=all"
|
||||
fi
|
||||
|
||||
version_tag="latest"
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
version_tag="$PYPI_VERSION"
|
||||
elif [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
version_tag="dev"
|
||||
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
version_tag="test-$TEST_PYPI_VERSION"
|
||||
fi
|
||||
|
||||
$DOCKER_BINARY run $DOCKER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
-v "$yaml_config:/app/config.yaml" \
|
||||
$mounts \
|
||||
$docker_image \
|
||||
$docker_image:$version_tag \
|
||||
python -m llama_stack.distribution.server.server \
|
||||
--yaml_config /app/config.yaml \
|
||||
--port $port "$@"
|
||||
--yaml-config /app/config.yaml \
|
||||
--port "$port"
|
||||
|
|
|
|||
7
llama_stack/distribution/store/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .registry import * # noqa: F401 F403
|
||||
221
llama_stack/distribution/store/registry.py
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
import pydantic
|
||||
|
||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
from llama_stack.providers.utils.kvstore import (
|
||||
KVStore,
|
||||
kvstore_impl,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
|
||||
|
||||
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
|
||||
|
||||
async def update(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider: ...
|
||||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
||||
|
||||
async def delete(self, type: str, identifier: str) -> None: ...
|
||||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v2"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
def _get_registry_key_range() -> Tuple[str, str]:
|
||||
"""Returns the start and end keys for the registry range query."""
|
||||
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
|
||||
return start_key, f"{start_key}\xff"
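For reference, the key layout this produces (assuming the constants above; the identifier is an example):

print(KEY_FORMAT.format(type="model", identifier="meta-llama/Llama-3.1-8B-Instruct"))
# distributions:registry:v2::model:meta-llama/Llama-3.1-8B-Instruct
start_key, end_key = _get_registry_key_range()
# start_key = "distributions:registry:v2", end_key = start_key + "\xff"
# i.e. the range query covers every key written with the registry prefix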
|
||||
|
||||
|
||||
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
|
||||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||
all_objects = []
|
||||
for value in values:
|
||||
obj = pydantic.parse_obj_as(
|
||||
RoutableObjectWithProvider,
|
||||
json.loads(value),
|
||||
)
|
||||
all_objects.append(obj)
|
||||
return all_objects
|
||||
|
||||
|
||||
class DiskDistributionRegistry(DistributionRegistry):
|
||||
def __init__(self, kvstore: KVStore):
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
def get_cached(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
# Disk registry does not have a cache
|
||||
raise NotImplementedError("Disk registry does not have a cache")
|
||||
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||
start_key, end_key = _get_registry_key_range()
|
||||
values = await self.kvstore.range(start_key, end_key)
|
||||
return _parse_registry_values(values)
|
||||
|
||||
async def get(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
json_str = await self.kvstore.get(
|
||||
KEY_FORMAT.format(type=type, identifier=identifier)
|
||||
)
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
objects_data = json.loads(json_str)
|
||||
# Return only the first object if any exist
|
||||
if objects_data:
|
||||
return pydantic.parse_obj_as(
|
||||
RoutableObjectWithProvider,
|
||||
json.loads(objects_data),
|
||||
)
|
||||
return None
|
||||
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
obj.model_dump_json(),
|
||||
)
|
||||
return obj
|
||||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# don't register if the object's provider_id already exists
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
return False
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
obj.model_dump_json(),
|
||||
)
|
||||
return True
|
||||
|
||||
async def delete(self, type: str, identifier: str) -> None:
|
||||
await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
|
||||
|
||||
|
||||
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||
def __init__(self, kvstore: KVStore):
|
||||
super().__init__(kvstore)
|
||||
self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
|
||||
self._initialized = False
|
||||
self._initialize_lock = asyncio.Lock()
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _locked_cache(self):
|
||||
"""Context manager for safely accessing the cache with a lock."""
|
||||
async with self._cache_lock:
|
||||
yield self.cache
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
"""Ensures the registry is initialized before operations."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
async with self._initialize_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
start_key, end_key = _get_registry_key_range()
|
||||
values = await self.kvstore.range(start_key, end_key)
|
||||
objects = _parse_registry_values(values)
|
||||
|
||||
async with self._locked_cache() as cache:
|
||||
for obj in objects:
|
||||
cache_key = (obj.type, obj.identifier)
|
||||
cache[cache_key] = obj
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self._ensure_initialized()
|
||||
|
||||
def get_cached(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
return self.cache.get((type, identifier), None)
|
||||
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||
await self._ensure_initialized()
|
||||
async with self._locked_cache() as cache:
|
||||
return list(cache.values())
|
||||
|
||||
async def get(
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
await self._ensure_initialized()
|
||||
cache_key = (type, identifier)
|
||||
|
||||
async with self._locked_cache() as cache:
|
||||
return cache.get(cache_key, None)
|
||||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
await self._ensure_initialized()
|
||||
success = await super().register(obj)
|
||||
|
||||
if success:
|
||||
cache_key = (obj.type, obj.identifier)
|
||||
async with self._locked_cache() as cache:
|
||||
cache[cache_key] = obj
|
||||
|
||||
return success
|
||||
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await super().update(obj)
|
||||
cache_key = (obj.type, obj.identifier)
|
||||
async with self._locked_cache() as cache:
|
||||
cache[cache_key] = obj
|
||||
return obj
|
||||
|
||||
async def delete(self, type: str, identifier: str) -> None:
|
||||
await super().delete(type, identifier)
|
||||
cache_key = (type, identifier)
|
||||
async with self._locked_cache() as cache:
|
||||
if cache_key in cache:
|
||||
del cache[cache_key]
|
||||
|
||||
|
||||
async def create_dist_registry(
|
||||
metadata_store: Optional[KVStoreConfig],
|
||||
image_name: str,
|
||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
if metadata_store:
|
||||
dist_kvstore = await kvstore_impl(metadata_store)
|
||||
else:
|
||||
dist_kvstore = await kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
|
||||
)
|
||||
)
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
await dist_registry.initialize()
|
||||
return dist_registry, dist_kvstore
|
||||
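For orientation (illustrative only, not part of the diff): a minimal sketch of how the registry assembled by create_dist_registry above might be exercised. The import path of create_dist_registry and the Model fields are assumed from the files touched in this commit.

import asyncio

from llama_stack.apis.inference import Model
from llama_stack.distribution.store import create_dist_registry  # path assumed


async def demo() -> None:
    # No explicit metadata store: falls back to the per-image SQLite kvstore shown above.
    dist_registry, _kvstore = await create_dist_registry(None, image_name="demo")

    model = Model(
        identifier="demo_model",
        provider_resource_id="demo_model",
        provider_id="demo-provider",
    )
    # register() returns False when the same identifier/provider_id pair already exists.
    assert await dist_registry.register(model)
    print(await dist_registry.get("model", "demo_model"))


asyncio.run(demo())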
llama_stack/distribution/store/tests/test_registry.py (new file, 215 lines)
@@ -0,0 +1,215 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from llama_stack.distribution.store import * # noqa F403
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.memory_banks import VectorMemoryBank
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
from llama_stack.distribution.datatypes import * # noqa F403
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
|
||||
if os.path.exists(config.db_path):
|
||||
os.remove(config.db_path)
|
||||
return config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def registry(config):
|
||||
registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def cached_registry(config):
|
||||
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_bank():
|
||||
return VectorMemoryBank(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
provider_resource_id="test_bank",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model():
|
||||
return Model(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_initialization(registry):
|
||||
# Test empty registry
|
||||
results = await registry.get("nonexistent", "nonexistent")
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_registration(registry, sample_bank, sample_model):
|
||||
print(f"Registering {sample_bank}")
|
||||
await registry.register(sample_bank)
|
||||
print(f"Registering {sample_model}")
|
||||
await registry.register(sample_model)
|
||||
print("Getting bank")
|
||||
results = await registry.get("memory_bank", "test_bank")
|
||||
assert len(results) == 1
|
||||
result_bank = results[0]
|
||||
assert result_bank.identifier == sample_bank.identifier
|
||||
assert result_bank.embedding_model == sample_bank.embedding_model
|
||||
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
|
||||
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||
assert result_bank.provider_id == sample_bank.provider_id
|
||||
|
||||
results = await registry.get("model", "test_model")
|
||||
assert len(results) == 1
|
||||
result_model = results[0]
|
||||
assert result_model.identifier == sample_model.identifier
|
||||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_initialization(config, sample_bank, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_bank)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
results = await cached_registry.get("memory_bank", "test_bank")
|
||||
assert len(results) == 1
|
||||
result_bank = results[0]
|
||||
assert result_bank.identifier == sample_bank.identifier
|
||||
assert result_bank.embedding_model == sample_bank.embedding_model
|
||||
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
|
||||
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||
assert result_bank.provider_id == sample_bank.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_updates(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
new_bank = VectorMemoryBank(
|
||||
identifier="test_bank_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=256,
|
||||
overlap_size_in_tokens=32,
|
||||
provider_resource_id="test_bank_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(new_bank)
|
||||
|
||||
# Verify in cache
|
||||
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||
assert len(results) == 1
|
||||
result_bank = results[0]
|
||||
assert result_bank.identifier == new_bank.identifier
|
||||
assert result_bank.provider_id == new_bank.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await new_registry.initialize()
|
||||
results = await new_registry.get("memory_bank", "test_bank_2")
|
||||
assert len(results) == 1
|
||||
result_bank = results[0]
|
||||
assert result_bank.identifier == new_bank.identifier
|
||||
assert result_bank.provider_id == new_bank.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_provider_registration(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
original_bank = VectorMemoryBank(
|
||||
identifier="test_bank_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=256,
|
||||
overlap_size_in_tokens=32,
|
||||
provider_resource_id="test_bank_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(original_bank)
|
||||
|
||||
duplicate_bank = VectorMemoryBank(
|
||||
identifier="test_bank_2",
|
||||
embedding_model="different-model",
|
||||
chunk_size_in_tokens=128,
|
||||
overlap_size_in_tokens=16,
|
||||
provider_resource_id="test_bank_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
await cached_registry.register(duplicate_bank)
|
||||
|
||||
results = await cached_registry.get("memory_bank", "test_bank_2")
|
||||
assert len(results) == 1 # Still only one result
|
||||
assert (
|
||||
results[0].embedding_model == original_bank.embedding_model
|
||||
) # Original values preserved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_objects(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
# Create multiple test banks
|
||||
test_banks = [
|
||||
VectorMemoryBank(
|
||||
identifier=f"test_bank_{i}",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=256,
|
||||
overlap_size_in_tokens=32,
|
||||
provider_resource_id=f"test_bank_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all banks
|
||||
for bank in test_banks:
|
||||
await cached_registry.register(bank)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each bank was stored correctly
|
||||
for original_bank in test_banks:
|
||||
matching_banks = [
|
||||
b for b in all_results if b.identifier == original_bank.identifier
|
||||
]
|
||||
assert len(matching_banks) == 1
|
||||
stored_bank = matching_banks[0]
|
||||
assert stored_bank.embedding_model == original_bank.embedding_model
|
||||
assert stored_bank.provider_id == original_bank.provider_id
|
||||
assert stored_bank.chunk_size_in_tokens == original_bank.chunk_size_in_tokens
|
||||
assert (
|
||||
stored_bank.overlap_size_in_tokens == original_bank.overlap_size_in_tokens
|
||||
)
|
||||
|
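To exercise the new test module locally, something like the following should work (illustrative; it assumes pytest and pytest-asyncio are installed, which the asyncio fixtures above require):

import sys

import pytest

# Run only the new registry tests, verbosely.
sys.exit(pytest.main(["-v", "llama_stack/distribution/store/tests/test_registry.py"]))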
|
@@ -1,15 +0,0 @@
|
|||
name: local-cpu
|
||||
distribution_spec:
|
||||
description: remote inference + local safety/agents/memory
|
||||
docker_image: null
|
||||
providers:
|
||||
inference:
|
||||
- remote::ollama
|
||||
- remote::tgi
|
||||
- remote::together
|
||||
- remote::fireworks
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
memory: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: docker
|
||||
|
|
@@ -1,49 +0,0 @@
|
|||
built_at: '2024-09-30T09:04:30.533391'
|
||||
image_name: local-cpu
|
||||
docker_image: local-cpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
- agents
|
||||
- inference
|
||||
- models
|
||||
- memory
|
||||
- safety
|
||||
- shields
|
||||
- memory_banks
|
||||
api_providers:
|
||||
inference:
|
||||
providers:
|
||||
- remote::ollama
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
agents:
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
host: localhost
|
||||
port: 6000
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
|
|
@@ -1,11 +0,0 @@
|
|||
name: local-gpu
|
||||
distribution_spec:
|
||||
description: local meta reference
|
||||
docker_image: null
|
||||
providers:
|
||||
inference: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
memory: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: docker
|
||||
|
|
@@ -1,52 +0,0 @@
|
|||
built_at: '2024-09-30T09:00:56.693751'
|
||||
image_name: local-gpu
|
||||
docker_image: local-gpu
|
||||
conda_env: null
|
||||
apis_to_serve:
|
||||
- memory
|
||||
- inference
|
||||
- agents
|
||||
- shields
|
||||
- safety
|
||||
- models
|
||||
- memory_banks
|
||||
api_providers:
|
||||
inference:
|
||||
providers:
|
||||
- meta-reference
|
||||
safety:
|
||||
providers:
|
||||
- meta-reference
|
||||
agents:
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
memory:
|
||||
providers:
|
||||
- meta-reference
|
||||
telemetry:
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
routing_table:
|
||||
inference:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
routing_key: Llama3.1-8B-Instruct
|
||||
safety:
|
||||
- provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield: null
|
||||
prompt_guard_shield: null
|
||||
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
|
||||
memory:
|
||||
- provider_type: meta-reference
|
||||
config: {}
|
||||
routing_key: vector
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-bedrock-conda-example
|
||||
distribution_spec:
|
||||
description: Use Amazon Bedrock APIs.
|
||||
providers:
|
||||
inference: remote::bedrock
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local
|
||||
distribution_spec:
|
||||
description: Use code from `llama_stack` itself to serve all llama stack APIs
|
||||
providers:
|
||||
inference: meta-reference
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-databricks
|
||||
distribution_spec:
|
||||
description: Use Databricks for running LLM inference
|
||||
providers:
|
||||
inference: remote::databricks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-fireworks
|
||||
distribution_spec:
|
||||
description: Use Fireworks.ai for running LLM inference
|
||||
providers:
|
||||
inference: remote::fireworks
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-hf-endpoint
|
||||
distribution_spec:
|
||||
description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
|
||||
providers:
|
||||
inference: remote::hf::endpoint
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-hf-serverless
|
||||
distribution_spec:
|
||||
description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
|
||||
providers:
|
||||
inference: remote::hf::serverless
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-ollama
|
||||
distribution_spec:
|
||||
description: Like local, but use ollama for running LLM inference
|
||||
providers:
|
||||
inference: remote::ollama
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-tgi
|
||||
distribution_spec:
|
||||
description: Like local, but use a TGI server for running LLM inference.
|
||||
providers:
|
||||
inference: remote::tgi
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-together
|
||||
distribution_spec:
|
||||
description: Use Together.ai for running LLM inference
|
||||
providers:
|
||||
inference: remote::together
|
||||
memory: meta-reference
|
||||
safety: remote::together
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -1,10 +0,0 @@
|
|||
name: local-vllm
|
||||
distribution_spec:
|
||||
description: Like local, but use vLLM for running LLM inference
|
||||
providers:
|
||||
inference: vllm
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
@@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import pty
|
||||
import select
|
||||
|
|
@@ -13,7 +14,7 @@ import subprocess
|
|||
import sys
|
||||
import termios
|
||||
|
||||
from termcolor import cprint
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# run a command in a pseudo-terminal, with interrupt handling,
|
||||
|
|
@@ -29,7 +30,7 @@ def run_with_pty(command):
|
|||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
ctrl_c_pressed = True
|
||||
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
|
||||
log.info("\nCtrl-C detected. Aborting...")
|
||||
|
||||
try:
|
||||
# Set up the signal handler
|
||||
|
|
@@ -100,6 +101,6 @@ def run_command(command):
|
|||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
if process.returncode != 0:
|
||||
print(f"Error: {error.decode('utf-8')}")
|
||||
log.error(f"Error: {error.decode('utf-8')}")
|
||||
sys.exit(1)
|
||||
return output.decode("utf-8")
|
||||
|
|
|
|||
|
|
@@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def model_local_dir(descriptor: str) -> str:
|
||||
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
|
||||
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))
|
||||
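The practical effect of this change, as an illustrative sketch (the base directory and descriptor below are assumed examples, not taken from the diff), is that colons in model descriptors no longer end up in the on-disk path:

from pathlib import Path

def model_local_dir_sketch(base: str, descriptor: str) -> str:
    # Mirrors the new implementation above: ":" is replaced before building the path.
    return str(Path(base) / descriptor.replace(":", "-"))

# old behaviour (assumed example): ~/.llama/checkpoints/Llama3.2-3B-Instruct:int4-qlora-eo8
# new behaviour:                   ~/.llama/checkpoints/Llama3.2-3B-Instruct-int4-qlora-eo8
print(model_local_dir_sketch("~/.llama/checkpoints", "Llama3.2-3B-Instruct:int4-qlora-eo8"))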
|
|
|
|||
|
|
@@ -6,6 +6,7 @@
|
|||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||
|
|
@@ -16,6 +17,8 @@ from pydantic_core import PydanticUndefinedType
|
|||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_list_of_primitives(field_type):
|
||||
"""Check if a field type is a List of primitive types."""
|
||||
|
|
@@ -111,7 +114,7 @@ def prompt_for_discriminated_union(
|
|||
|
||||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
log.info(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator) != discriminator_value
|
||||
|
|
@@ -123,7 +126,7 @@ def prompt_for_discriminated_union(
|
|||
setattr(sub_config, discriminator, discriminator_value)
|
||||
return sub_config
|
||||
else:
|
||||
print(f"Invalid {discriminator}. Please try again.")
|
||||
log.error(f"Invalid {discriminator}. Please try again.")
|
||||
|
||||
|
||||
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
||||
|
|
@@ -180,7 +183,7 @@ def prompt_for_config(
|
|||
config_data[field_name] = validated_value
|
||||
break
|
||||
except KeyError:
|
||||
print(
|
||||
log.error(
|
||||
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
||||
)
|
||||
continue
|
||||
|
|
@@ -197,7 +200,7 @@ def prompt_for_config(
|
|||
config_data[field_name] = None
|
||||
continue
|
||||
nested_type = get_non_none_type(field_type)
|
||||
print(f"Entering sub-configuration for {field_name}:")
|
||||
log.info(f"Entering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||
elif is_optional(field_type) and is_discriminated_union(
|
||||
get_non_none_type(field_type)
|
||||
|
|
@@ -213,7 +216,7 @@ def prompt_for_config(
|
|||
existing_value,
|
||||
)
|
||||
elif can_recurse(field_type):
|
||||
print(f"\nEntering sub-configuration for {field_name}:")
|
||||
log.info(f"\nEntering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(
|
||||
field_type,
|
||||
existing_value,
|
||||
|
|
@@ -240,7 +243,7 @@ def prompt_for_config(
|
|||
config_data[field_name] = None
|
||||
break
|
||||
else:
|
||||
print("This field is required. Please provide a value.")
|
||||
log.error("This field is required. Please provide a value.")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
|
|
@@ -264,12 +267,12 @@ def prompt_for_config(
|
|||
value = [element_type(item) for item in value]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
log.error(
|
||||
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
||||
)
|
||||
continue
|
||||
except ValueError as e:
|
||||
print(f"{str(e)}")
|
||||
log.error(f"{str(e)}")
|
||||
continue
|
||||
|
||||
elif get_origin(field_type) is dict:
|
||||
|
|
@@ -281,7 +284,7 @@ def prompt_for_config(
|
|||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
log.error(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||
)
|
||||
continue
|
||||
|
|
@@ -298,7 +301,7 @@ def prompt_for_config(
|
|||
value = field_type(user_input)
|
||||
|
||||
except ValueError:
|
||||
print(
|
||||
log.error(
|
||||
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
||||
)
|
||||
continue
|
||||
|
|
@@ -311,6 +314,6 @@ def prompt_for_config(
|
|||
config_data[field_name] = validated_value
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Validation error: {str(e)}")
|
||||
log.error(f"Validation error: {str(e)}")
|
||||
|
||||
return config_type(**config_data)
|
||||
|
|
|
|||
|
|
@@ -1,257 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(Inference):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
return OpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_databricks_messages(self, messages: list[Message]) -> list:
|
||||
databricks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
databricks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return databricks_messages
|
||||
|
||||
def resolve_databricks_model(self, model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert (
|
||||
model is not None
|
||||
and model.descriptor(shorten_default_variant=True)
|
||||
in DATABRICKS_SUPPORTED_MODELS
|
||||
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
|
||||
|
||||
return DATABRICKS_SUPPORTED_MODELS.get(
|
||||
model.descriptor(shorten_default_variant=True)
|
||||
)
|
||||
|
||||
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
options = self.get_databricks_chat_options(request)
|
||||
databricks_model = self.resolve_databricks_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
|
||||
r = self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
for chunk in self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "stop"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
|
||||
if text is None:
|
||||
continue
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
@@ -1,247 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
FIREWORKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Fireworks:
|
||||
return Fireworks(api_key=self.config.api_key)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
|
||||
fireworks_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
fireworks_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return fireworks_messages
|
||||
|
||||
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
fireworks_model = self.map_to_provider_model(request.model)
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk.choices[0].delta.content
|
||||
if text is None:
|
||||
continue
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
@@ -1,266 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
}
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
|
||||
def __init__(self, url: str) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
|
||||
)
|
||||
self.url = url
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print("Initializing Ollama, checking connectivity to server...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
) from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
|
||||
ollama_messages = []
|
||||
for message in messages:
|
||||
if message.role == "ipython":
|
||||
role = "tool"
|
||||
else:
|
||||
role = message.role
|
||||
ollama_messages.append({"role": role, "content": message.content})
|
||||
|
||||
return ollama_messages
|
||||
|
||||
def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
if (
|
||||
request.sampling_params.repetition_penalty is not None
|
||||
and request.sampling_params.repetition_penalty != 1.0
|
||||
):
|
||||
options["repeat_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
for r in res["models"]:
|
||||
if ollama_model == r["model"]:
|
||||
need_model_pull = False
|
||||
break
|
||||
|
||||
if need_model_pull:
|
||||
print(f"Pulling model: {ollama_model}")
|
||||
status = await self.client.pull(ollama_model)
|
||||
assert (
|
||||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=False,
|
||||
options=options,
|
||||
)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["message"]["content"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
stream = await self.client.chat(
|
||||
model=ollama_model,
|
||||
messages=self._messages_to_ollama_messages(messages),
|
||||
stream=True,
|
||||
options=options,
|
||||
)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk["done"]:
|
||||
if stop_reason is None and chunk["done_reason"] == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif stop_reason is None and chunk["done_reason"] == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk["message"]["content"]
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer += text
|
||||
continue
|
||||
|
||||
if ipython:
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
|
||||
buffer += text
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
else:
|
||||
buffer += text
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
@@ -1,260 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _HfAdapter(Inference, RoutableProvider):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(request.sampling_params, attr):
|
||||
options[attr] = getattr(request.sampling_params, attr)
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
input_tokens = len(model_input.tokens)
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
)
|
||||
stop_reason = None
|
||||
if response.details.finish_reason:
|
||||
if response.details.finish_reason in ["stop", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif response.details.finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
response.generated_text,
|
||||
stop_reason,
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
buffer = ""
|
||||
ipython = False
|
||||
stop_reason = None
|
||||
tokens = []
|
||||
|
||||
async for response in await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
):
|
||||
token_result = response.token
|
||||
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.id)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = self.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.model_id, token=config.api_token
|
||||
)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token)
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(
|
||||
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
|
||||
)
|
||||
|
|
@@ -1,265 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer

from together import Together

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
    augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels

from .config import TogetherImplConfig


TOGETHER_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}


class TogetherInferenceAdapter(
    Inference, NeedsRequestProviderData, RoutableProviderForModels
):
    def __init__(self, config: TogetherImplConfig) -> None:
        RoutableProviderForModels.__init__(
            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
        )
        self.config = config
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    @property
    def client(self) -> Together:
        return Together(api_key=self.config.api_key)

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    def _messages_to_together_messages(self, messages: list[Message]) -> list:
        together_messages = []
        for message in messages:
            if message.role == "ipython":
                role = "tool"
            else:
                role = message.role
            together_messages.append({"role": role, "content": message.content})

        return together_messages

    def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        if request.sampling_params is not None:
            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
                if getattr(request.sampling_params, attr):
                    options[attr] = getattr(request.sampling_params, attr)

        return options

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        together_api_key = None
        if self.config.api_key is not None:
            together_api_key = self.config.api_key
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.together_api_key:
                raise ValueError(
                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
                )
            together_api_key = provider_data.together_api_key

        client = Together(api_key=together_api_key)
        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        # accumulate sampling params and other options to pass to together
        options = self.get_together_chat_options(request)
        together_model = self.map_to_provider_model(request.model)
        messages = augment_messages_for_tools(request)

        if not request.stream:
            # TODO: might need to add back an async here
            r = client.chat.completions.create(
                model=together_model,
                messages=self._messages_to_together_messages(messages),
                stream=False,
                **options,
            )
            stop_reason = None
            if r.choices[0].finish_reason:
                if (
                    r.choices[0].finish_reason == "stop"
                    or r.choices[0].finish_reason == "eos"
                ):
                    stop_reason = StopReason.end_of_turn
                elif r.choices[0].finish_reason == "length":
                    stop_reason = StopReason.out_of_tokens

            completion_message = self.formatter.decode_assistant_message_from_content(
                r.choices[0].message.content, stop_reason
            )
            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )

            buffer = ""
            ipython = False
            stop_reason = None

            for chunk in client.chat.completions.create(
                model=together_model,
                messages=self._messages_to_together_messages(messages),
                stream=True,
                **options,
            ):
                if finish_reason := chunk.choices[0].finish_reason:
                    if stop_reason is None and finish_reason in ["stop", "eos"]:
                        stop_reason = StopReason.end_of_turn
                    elif stop_reason is None and finish_reason == "length":
                        stop_reason = StopReason.out_of_tokens
                    break

                text = chunk.choices[0].delta.content
                if text is None:
                    continue

                # check if its a tool call ( aka starts with <|python_tag|> )
                if not ipython and text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    buffer += text
                    continue

                if ipython:
                    if text == "<|eot_id|>":
                        stop_reason = StopReason.end_of_turn
                        text = ""
                        continue
                    elif text == "<|eom_id|>":
                        stop_reason = StopReason.end_of_message
                        text = ""
                        continue

                    buffer += text
                    delta = ToolCallDelta(
                        content=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                        )
                    )
                else:
                    buffer += text
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=text,
                            stop_reason=stop_reason,
                        )
                    )

            # parse tool calls and report errors
            message = self.formatter.decode_assistant_message_from_content(
                buffer, stop_reason
            )
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )
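The ValueError above documents how a per-request Together key is expected to arrive when the distribution does not configure one. A hedged sketch of that header as a client might send it; only the header name and JSON shape come from the code, the route and the client call are placeholders:

```python
import json

import httpx

headers = {
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}
# Placeholder request against a running Llama Stack server; the route is illustrative.
# httpx.post("http://localhost:5000/inference/chat_completion", headers=headers, json={...})
```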
@@ -1,8 +0,0 @@
from .config import WeaviateConfig


async def get_adapter_impl(config: WeaviateConfig, _deps):
    from .weaviate import WeaviateMemoryAdapter

    impl = WeaviateMemoryAdapter(config)
    await impl.initialize()
    return impl
@@ -1,18 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


class WeaviateRequestProviderData(BaseModel):
    # if there _is_ provider data, it must specify the API KEY
    # if you want it to be optional, use Optional[str]
    weaviate_api_key: str
    weaviate_cluster_url: str


@json_schema_type
class WeaviateConfig(BaseModel):
    collection: str = Field(default="MemoryBank")
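Per the comments above, once Weaviate provider data is supplied it must carry both fields. A sketch of the corresponding X-LlamaStack-ProviderData payload, assuming the same per-request header convention the Together providers in this diff rely on; the key and cluster URL are placeholders:

```python
import json

provider_data_header = json.dumps(
    {
        "weaviate_api_key": "<your weaviate api key>",
        "weaviate_cluster_url": "https://<your-cluster>.weaviate.network",
    }
)
```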
@@ -1,192 +0,0 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from numpy.typing import NDArray

import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth

from llama_stack.apis.memory import *
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.memory.vector_store import (
    BankWithIndex,
    EmbeddingIndex,
)

from .config import WeaviateConfig, WeaviateRequestProviderData


class WeaviateIndex(EmbeddingIndex):
    def __init__(self, client: weaviate.Client, collection: str):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        data_objects = []
        for i, chunk in enumerate(chunks):
            data_objects.append(
                wvc.data.DataObject(
                    properties={
                        "chunk_content": chunk,
                    },
                    vector=embeddings[i].tolist(),
                )
            )

        # Inserting chunks into a prespecified Weaviate collection
        assert self.collection is not None, "Collection name must be specified"
        my_collection = self.client.collections.get(self.collection)

        await my_collection.data.insert_many(data_objects)

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        assert self.collection is not None, "Collection name must be specified"

        my_collection = self.client.collections.get(self.collection)

        results = my_collection.query.near_vector(
            near_vector=embedding.tolist(),
            limit=k,
            return_meta_data=wvc.query.MetadataQuery(distance=True),
        )

        chunks = []
        scores = []
        for doc in results.objects:
            try:
                chunk = doc.properties["chunk_content"]
                chunks.append(chunk)
                scores.append(1.0 / doc.metadata.distance)

            except Exception as e:
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {e}")

        return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory):
    def __init__(self, config: WeaviateConfig) -> None:
        self.config = config
        self.client = None
        self.cache = {}

    def _get_client(self) -> weaviate.Client:
        request_provider_data = get_request_provider_data()

        if request_provider_data is not None:
            assert isinstance(request_provider_data, WeaviateRequestProviderData)

            # Connect to Weaviate Cloud
            return weaviate.connect_to_weaviate_cloud(
                cluster_url=request_provider_data.weaviate_cluster_url,
                auth_credentials=Auth.api_key(request_provider_data.weaviate_api_key),
            )

    async def initialize(self) -> None:
        try:
            self.client = self._get_client()

            # Create collection if it doesn't exist
            if not self.client.collections.exists(self.config.collection):
                self.client.collections.create(
                    name=self.config.collection,
                    vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                    properties=[
                        wvc.config.Property(
                            name="chunk_content",
                            data_type=wvc.config.DataType.TEXT,
                        ),
                    ],
                )

        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to Weaviate server") from e

    async def shutdown(self) -> None:
        self.client = self._get_client()

        if self.client:
            self.client.close()

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        self.client = self._get_client()

        # Store the bank as a new collection in Weaviate
        self.client.collections.create(name=bank_id)

        index = BankWithIndex(
            bank=bank,
            index=WeaviateIndex(client=self.client, collection=bank_id),
        )
        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        self.client = self._get_client()

        if bank_id in self.cache:
            return self.cache[bank_id]

        collections = await self.client.collections.list_all().keys()

        for collection in collections:
            if collection == bank_id:
                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                index = BankWithIndex(
                    bank=bank,
                    index=WeaviateIndex(self.client, collection),
                )
                self.cache[bank_id] = index
                return index

        return None

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
@@ -1,120 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging

import traceback
from typing import Any, Dict, List

import boto3

from llama_stack.apis.safety import *  # noqa
from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider

from .config import BedrockSafetyConfig


logger = logging.getLogger(__name__)


SUPPORTED_SHIELD_TYPES = [
    "bedrock_guardrail",
]


class BedrockSafetyAdapter(Safety, RoutableProvider):
    def __init__(self, config: BedrockSafetyConfig) -> None:
        if not config.aws_profile:
            raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
        self.config = config

    async def initialize(self) -> None:
        try:
            print(f"initializing with profile --- > {self.config}")
            self.boto_client = boto3.Session(
                profile_name=self.config.aws_profile
            ).client("bedrock-runtime")
        except Exception as e:
            raise RuntimeError("Error initializing BedrockSafetyAdapter") from e

    async def shutdown(self) -> None:
        pass

    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
        for key in routing_keys:
            if key not in SUPPORTED_SHIELD_TYPES:
                raise ValueError(f"Unknown safety shield type: {key}")

    async def run_shield(
        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse:
        if shield_type not in SUPPORTED_SHIELD_TYPES:
            raise ValueError(f"Unknown safety shield type: {shield_type}")

        """This is the implementation for the Bedrock guardrails. The input to the guardrails is expected to be of this format:
        ```content = [
            {
                "text": {
                    "text": "Is the AB503 Product a better investment than the S&P 500?"
                }
            }
        ]```
        However the incoming messages are of the type UserMessage(content=....) coming from
        https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py

        They contain content and role. For now we extract the content and default the "qualifiers" to ["query"].
        """
        try:
            logger.debug(f"run_shield::{params}::messages={messages}")
            if "guardrailIdentifier" not in params:
                raise RuntimeError(
                    "Error running request for BedrockGuardrails: Missing guardrailIdentifier in request"
                )

            if "guardrailVersion" not in params:
                raise RuntimeError(
                    "Error running request for BedrockGuardrails: Missing guardrailVersion in request"
                )

            # - convert the messages into the format Bedrock expects
            content_messages = []
            for message in messages:
                content_messages.append({"text": {"text": message.content}})
            logger.debug(
                f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
            )

            response = self.boto_client.apply_guardrail(
                guardrailIdentifier=params.get("guardrailIdentifier"),
                guardrailVersion=params.get("guardrailVersion"),
                source="OUTPUT",  # or 'INPUT' depending on your use case
                content=content_messages,
            )
            logger.debug(f"run_shield:: response: {response}::")
            if response["action"] == "GUARDRAIL_INTERVENED":
                user_message = ""
                metadata = {}
                for output in response["outputs"]:
                    # guardrails returns a list - however for this implementation we will leverage the last values
                    user_message = output["text"]
                for assessment in response["assessments"]:
                    # guardrails returns a list - however for this implementation we will leverage the last values
                    metadata = dict(assessment)
                return SafetyViolation(
                    user_message=user_message,
                    violation_level=ViolationLevel.ERROR,
                    metadata=metadata,
                )

        except Exception:
            error_str = traceback.format_exc()
            logger.error(
                f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
            )

        return None
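For reference, a hedged sketch of the params and message shapes that run_shield above validates and builds; the guardrail identifier and version are placeholders, and the sample text is taken from the docstring:

```python
params = {
    "guardrailIdentifier": "<your-guardrail-id>",
    "guardrailVersion": "DRAFT",
}
# Shape the adapter converts incoming messages into before calling apply_guardrail.
content_messages = [
    {"text": {"text": "Is the AB503 Product a better investment than the S&P 500?"}},
]
```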
@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field


class BedrockSafetyConfig(BaseModel):
    """Configuration information for a guardrail that you want to use in the request."""

    aws_profile: str = Field(
        default="default",
        description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
    )
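The aws_profile value above is handed straight to boto3.Session(profile_name=...) in the adapter, so it must name a profile in the local AWS credentials file. A small sketch with placeholder values:

```python
# ~/.aws/credentials is expected to contain a matching profile, e.g.:
#
#   [default]
#   aws_access_key_id = <key id>
#   aws_secret_access_key = <secret>
config = BedrockSafetyConfig(aws_profile="default")
```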
@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


class TogetherProviderDataValidator(BaseModel):
    together_api_key: str


@json_schema_type
class TogetherSafetyConfig(BaseModel):
    url: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together AI server",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The Together AI API Key (default for the distribution, if any)",
    )
Some files were not shown because too many files have changed in this diff.