[API Updates] Model / shield / memory-bank routing + agent persistence + support for private headers (#92)

This is yet another of those large PRs (hopefully we will have less and less of them as things mature fast). This one introduces substantial improvements and some simplifications to the stack.

Most important bits:

* Agents reference implementation now has support for session / turn persistence. The default implementation uses sqlite but there's also support for using Redis.

* We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are:
  - routing model A to ollama and model B to a remote provider like Together
  - routing shield A to local impl while shield B to a remote provider like Bedrock
  - routing a vector memory bank to Weaviate while routing a keyvalue memory bank to Redis

* Support for provider specific parameters to be passed from the clients. A client can pass data using `x_llamastack_provider_data` parameter which can be type-checked and provided to the Adapter implementations.
This commit is contained in:
Ashwin Bharambe 2024-09-23 14:22:22 -07:00 committed by GitHub
parent 8bf8c07eb3
commit ec4fc800cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
130 changed files with 9701 additions and 11227 deletions

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleAgentsImpl
impl = SampleAgentsImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SampleConfig
from llama_stack.apis.agents import * # noqa: F403
class SampleAgentsImpl(Agents):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -6,14 +6,14 @@
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from fireworks.client import Fireworks
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
@ -42,7 +42,14 @@ class FireworksInferenceAdapter(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:

View file

@ -38,6 +38,7 @@ class OllamaInferenceAdapter(Inference):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print("Initializing Ollama, checking connectivity to server...")
try:
await self.client.ps()
except httpx.ConnectError as e:
@ -48,7 +49,14 @@ class OllamaInferenceAdapter(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleInferenceImpl
impl = SampleInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SampleConfig
from llama_stack.apis.inference import * # noqa: F403
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -54,7 +54,14 @@ class TGIAdapter(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def get_chat_options(self, request: ChatCompletionRequest) -> dict:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig
from .config import TogetherImplConfig, TogetherHeaderExtractor
async def get_adapter_impl(config: TogetherImplConfig, _deps):

View file

@ -4,9 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
from llama_stack.distribution.request_headers import annotate_header
class TogetherHeaderExtractor(BaseModel):
api_key: annotate_header(
"X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
)
@json_schema_type
class TogetherImplConfig(BaseModel):

View file

@ -42,7 +42,14 @@ class TogetherInferenceAdapter(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def _messages_to_together_messages(self, messages: list[Message]) -> list:

View file

@ -31,9 +31,6 @@ class ChromaIndex(EmbeddingIndex):
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
await self.collection.add(
documents=[chunk.json() for chunk in chunks],
embeddings=embeddings,

View file

@ -80,7 +80,6 @@ class PGVectorIndex(EmbeddingIndex):
values = []
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
values.append(
(
f"{chunk.document_id}:chunk-{i}",

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleMemoryImpl
impl = SampleMemoryImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SampleConfig
from llama_stack.apis.memory import * # noqa: F403
class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleSafetyImpl
impl = SampleSafetyImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SampleConfig
from llama_stack.apis.safety import * # noqa: F403
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
from .opentelemetry import OpenTelemetryAdapter
impl = OpenTelemetryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost"
jaeger_port: int = 6831

View file

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
self.resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
)
# Set up tracing with Jaeger exporter
jaeger_exporter = JaegerExporter(
agent_host_name=self.config.jaeger_host,
agent_port=self.config.jaeger_port,
)
trace_provider = TracerProvider(resource=self.resource)
trace_processor = BatchSpanProcessor(jaeger_exporter)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span()
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=event.timestamp,
)
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
self.meter.create_counter(
name=event.metric,
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
self.meter.create_gauge(
name=event.metric,
unit=event.unit,
description=f"Gauge for {event.metric}",
).set(event.value, attributes=event.attributes)
def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
)
)
elif isinstance(event.payload, SpanEndPayload):
span = trace.get_current_span()
span.set_status(
trace.Status(
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryTelemetry("my-service")
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleTelemetryImpl
impl = SampleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SampleConfig
from llama_stack.apis.telemetry import * # noqa: F403
class SampleTelemetryImpl(Telemetry):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -8,18 +8,14 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceImplConfig
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
):
from .agents import MetaReferenceAgentsImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],

View file

@ -25,10 +25,21 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import SafeTool
def make_random_string(length: int = 8):
@ -40,23 +51,44 @@ def make_random_string(length: int = 8):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
agent_id: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
max_infer_iters: int = 10,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
SafeTool(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
ShieldRunnerMixin.__init__(
self,
@ -80,7 +112,6 @@ class ChatAgent(ShieldRunnerMixin):
msg.context = None
messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
@ -94,43 +125,35 @@ class ChatAgent(ShieldRunnerMixin):
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
content=violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4())
session = Session(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
)
self.sessions[session_id] = session
return session
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session = self.sessions[request.session_id]
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(session.turns) == 0 and self.agent_config.instructions != "":
if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(session.turns):
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
@ -148,7 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
output_message = None
async for chunk in self.run(
session=session,
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
@ -187,7 +210,7 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -200,7 +223,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -212,7 +235,7 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
@ -221,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream
session_id, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -235,7 +258,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
@ -245,11 +268,12 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
async def run_shields_wrapper(
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[ShieldDefinition],
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
@ -266,7 +290,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
await self.run_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
@ -276,7 +300,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@ -295,12 +319,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@ -308,7 +327,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -332,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments
)
with tracing.span("retrieve_rag_context"):
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -387,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
)
)
)
)
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
)
)
)
)
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
if event.stop_reason is not None:
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
message = CompletionMessage(
@ -461,7 +483,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
if n_iter >= self.max_infer_iters:
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break
@ -512,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
with tracing.span("tool_execution"):
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -550,12 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@ -569,7 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@ -594,17 +612,25 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
bank_id = memory_bank.bank_id
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return session.memory_bank
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
@ -619,7 +645,6 @@ class ChatAgent(ShieldRunnerMixin):
else:
return True
print(f"{enabled_tools=}")
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
@ -630,7 +655,7 @@ class ChatAgent(ShieldRunnerMixin):
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = []
@ -639,8 +664,8 @@ class ChatAgent(ShieldRunnerMixin):
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
@ -651,9 +676,12 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
await self.memory_api.insert_documents(bank.bank_id, documents)
elif session.memory_bank:
bank_ids.append(session.memory_bank.bank_id)
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import tempfile
import uuid
from typing import AsyncGenerator
@ -15,28 +14,19 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
from .tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
logger = logging.getLogger()
logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
config: MetaReferenceImplConfig,
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
@ -45,9 +35,10 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.in_memory_store = InmemoryKVStoreImpl()
async def initialize(self) -> None:
pass
self.persistence_store = await kvstore_impl(self.config.persistence_store)
async def create_agent(
self,
@ -55,38 +46,46 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
builtin_tools.append(
with_safety(
tool,
self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
async def get_agent(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
)
return AgentCreateResponse(
agent_id=agent_id,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
else self.in_memory_store
),
)
async def create_agent_session(
@ -94,12 +93,11 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
agent = await self.get_agent(agent_id)
session = agent.create_session(session_name)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session.session_id,
session_id=session_id,
)
async def create_agent_turn(
@ -115,6 +113,8 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -124,12 +124,5 @@ class MetaReferenceAgentsImpl(Agents):
stream=stream,
)
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request):
yield event

View file

@ -6,5 +6,8 @@
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig
class MetaReferenceImplConfig(BaseModel): ...
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
started_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
self.kvstore = kvstore
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
return session_id
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return AgentSessionInfo(**json.loads(value))
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
continue
return turns

View file

@ -4,51 +4,48 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from llama_stack.apis.safety import (
OnViolationAction,
Safety,
ShieldDefinition,
ShieldResponse,
)
from llama_stack.apis.safety import * # noqa: F403
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
def __init__(self, violation: SafetyViolation):
self.violation = violation
super().__init__(violation.user_message)
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
results = await self.safety_api.run_shields(
messages=messages,
shields=shields,
async def run_multiple_shields(
self, messages: List[Message], shields: List[str]
) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
messages=messages,
)
for shield_type in shields
]
)
for shield, r in zip(shields, results):
if r.is_violation:
for shield, r in zip(shields, responses):
if r.violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
@ -56,5 +53,3 @@ class ShieldRunnerMixin:
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from typing import AsyncIterator, List, Optional, Union
from unittest.mock import MagicMock
import pytest
@ -79,10 +78,10 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shields(
self, messages: List[Message], shields: List[MagicMock]
) -> List[ShieldResponse]:
return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
async def run_shield(
self, shield_type: str, messages: List[Message]
) -> RunShieldResponse:
return RunShieldResponse(violation=None)
class MockMemoryAPI:
@ -185,6 +184,7 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
# ),
],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[],
output_shields=[],
)
@ -221,13 +221,13 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
@pytest.mark.asyncio
async def test_run_shields_wrapper(chat_agent):
async def test_run_multiple_shields_wrapper(chat_agent):
messages = [UserMessage(content="Test message")]
shields = [ShieldDefinition(shield_type="test_shield")]
shields = ["test_shield"]
responses = [
chunk
async for chunk in chat_agent.run_shields_wrapper(
async for chunk in chat_agent.run_multiple_shields_wrapper(
turn_id="test_turn_id",
messages=messages,
shields=shields,

View file

@ -7,7 +7,7 @@
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, ShieldDefinition
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
@ -30,29 +30,14 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
)
def get_name(self) -> str:
# return the name of the wrapped tool
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_shields(messages, self.input_shields)
await self.run_multiple_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_shields(messages, self.output_shields)
await self.run_multiple_shields(messages, self.output_shields)
return res
def with_safety(
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
) -> SafeTool:
return SafeTool(
tool,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
)

View file

@ -6,17 +6,14 @@
from typing import Optional
from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Meta-Llama3.1-8B-Instruct",
@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
m.descriptor()
for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)

View file

@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tools: Optional[List[ToolDefinition]] = [],
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,

View file

@ -42,7 +42,6 @@ class FaissIndex(EmbeddingIndex):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
@ -11,6 +12,13 @@ from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator
class MetaReferenceShieldType(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-8B"
excluded_categories: List[str] = []

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from .config import MetaReferenceShieldType, SafetyConfig
from .config import SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
@ -19,7 +19,6 @@ from .shields import (
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
ThirdPartyShield,
)
@ -50,46 +49,58 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def run_shields(
async def run_shield(
self,
shield_type: str,
messages: List[Message],
shields: List[ShieldDefinition],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in shields]
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
return RunShieldResponse(responses=responses)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
def shield_type_equals(a: ShieldType, b: ShieldType):
return a == b or a == b.value
return RunShieldResponse(violation=violation)
def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
assert (
safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
return CodeScannerShield.instance()
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
return ThirdPartyShield.instance()
else:
raise ValueError(f"Unknown shield type: {sc.shield_type}")
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
assert (
cfg.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard:
return CodeScannerShield.instance()
else:
raise ValueError(f"Unknown shield type: {typ}")

View file

@ -15,7 +15,6 @@ from .base import ( # noqa: F401
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,

View file

@ -8,11 +8,26 @@ from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
class ShieldBase(ABC):
def __init__(
self,
@ -20,10 +35,6 @@ class ShieldBase(ABC):
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
@ -48,11 +59,6 @@ class TextShield(ShieldBase):
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(
shield_type=BuiltinShield.third_party_shield, is_violation=False
)
return ShieldResponse(is_violation=False)

View file

@ -7,13 +7,9 @@
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_stack.apis.safety import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)
return ShieldResponse(is_violation=False)

View file

@ -1,35 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import Message
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
super.run() # will raise NotImplementedError

View file

@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.apis.safety import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:

View file

@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_stack.apis.safety import * # noqa: F403
class PromptGuardShield(TextShield):
@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)

View file

@ -6,7 +6,8 @@
from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.utils.kvstore import kvstore_dependencies
def available_providers() -> List[ProviderSpec]:
@ -19,15 +20,23 @@ def available_providers() -> List[ProviderSpec]:
"pillow",
"pandas",
"scikit-learn",
"torch",
"transformers",
],
]
+ kvstore_dependencies(),
module="llama_stack.providers.impls.meta_reference.agents",
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig",
api_dependencies=[
Api.inference,
Api.safety,
Api.memory,
],
),
remote_provider_spec(
api=Api.agents,
adapter=AdapterSpec(
adapter_id="sample",
pip_packages=[],
module="llama_stack.providers.adapters.agents.sample",
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",
),
),
]

View file

@ -26,6 +26,15 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="sample",
pip_packages=[],
module="llama_stack.providers.adapters.inference.sample",
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@ -63,6 +72,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
),
),
]

View file

@ -42,4 +42,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
adapter_id="sample",
pip_packages=[],
module="llama_stack.providers.adapters.memory.sample",
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
),
),
]

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
@ -23,4 +23,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="sample",
pip_packages=[],
module="llama_stack.providers.adapters.safety.sample",
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
),
),
]

View file

@ -18,4 +18,27 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.impls.meta_reference.telemetry",
config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_id="sample",
pip_packages=[],
module="llama_stack.providers.adapters.telemetry.sample",
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_id="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-jaeger",
"opentelemetry-semantic-conventions",
],
module="llama_stack.providers.adapters.telemetry.opentelemetry",
config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig",
),
),
]

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
from .memory import MemoryRouterImpl
impl = MemoryRouterImpl(inner_impls, deps)
await impl.initialize()
return impl

View file

@ -1,91 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
from llama_stack.apis.memory import * # noqa: F403
class MemoryRouterImpl(Memory):
"""Routes to an provider based on the memory bank type"""
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
deps: List[Api],
) -> None:
self.deps = deps
bank_types = [v.value for v in MemoryBankType]
self.providers = {}
for routing_key, provider_impl in inner_impls:
if routing_key not in bank_types:
raise ValueError(
f"Unknown routing key `{routing_key}` for memory bank type"
)
self.providers[routing_key] = provider_impl
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for p in self.providers.values():
await p.shutdown()
def get_provider(self, bank_type):
if bank_type not in self.providers:
raise ValueError(f"Memory bank type {bank_type} not supported")
return self.providers[bank_type]
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
provider = self.get_provider(config.type)
bank = await provider.create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = config.type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.insert_documents(bank_id, documents, ttl_seconds)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.query_documents(bank_id, query, params)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .kvstore import * # noqa: F401, F403

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional, Protocol
class KVStore(Protocol):
# TODO: make the value type bytes instead of str
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ...
async def get(self, key: str) -> Optional[str]: ...
async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> List[str]: ...

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
postgres = "postgres"
class CommonConfig(BaseModel):
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
class RedisKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
host: str = "localhost"
port: int = 6379
class SqliteKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
description="File path for the sqlite database",
)
class PostgresKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
user: str
password: Optional[str] = None
KVStoreConfig = Annotated[
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value),
]

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F403
from .config import * # noqa: F403
def kvstore_dependencies():
return ["aiosqlite", "psycopg2-binary", "redis"]
class InmemoryKVStoreImpl(KVStore):
def __init__(self):
self._store = {}
async def initialize(self) -> None:
pass
async def get(self, key: str) -> Optional[str]:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> List[str]:
return [
self._store[key]
for key in self._store.keys()
if key >= start_key and key < end_key
]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:
from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == KVStoreType.sqlite.value:
from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == KVStoreType.postgres.value:
raise NotImplementedError()
else:
raise ValueError(f"Unknown kvstore type {config.type}")
await impl.initialize()
return impl

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .redis import RedisKVStoreImpl # noqa: F401

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional
from redis.asyncio import Redis
from ..api import * # noqa: F403
from ..config import RedisKVStoreConfig
class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreConfig):
self.config = config
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
return value
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
return await self.redis.zrangebylex(start_key, end_key)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .sqlite import SqliteKVStoreImpl # noqa: F401

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from datetime import datetime
from typing import List, Optional
import aiosqlite
from ..api import * # noqa: F403
from ..config import SqliteKVStoreConfig
class SqliteKVStoreImpl(KVStore):
def __init__(self, config: SqliteKVStoreConfig):
self.db_path = config.db_path
self.table_name = "kvstore"
async def initialize(self):
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await db.commit()
async def get(self, key: str) -> Optional[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
return value
async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> List[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result

View file

@ -16,6 +16,7 @@ import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@ -160,6 +161,8 @@ class BankWithIndex:
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
if not chunks:
continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)

View file

@ -12,7 +12,7 @@ import threading
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List
from llama_stack.apis.telemetry import * # noqa: F403
@ -196,33 +196,40 @@ class TelemetryHandler(logging.Handler):
pass
def span(name: str, attributes: Dict[str, Any] = None):
def decorator(func):
class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None):
self.name = name
self.attributes = attributes
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
async def __aenter__(self):
return self.__enter__()
async def __aexit__(self, exc_type, exc_value, traceback):
self.__exit__(exc_type, exc_value, traceback)
def __call__(self, func: Callable):
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = func(*args, **kwargs)
finally:
context.pop_span()
return result
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = await func(*args, **kwargs)
finally:
context.pop_span()
return result
async with self:
return await func(*args, **kwargs)
@wraps(func)
def wrapper(*args, **kwargs):
@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
return wrapper
return decorator
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)