basic RAG seems to work

This commit is contained in:
Ashwin Bharambe 2024-08-24 23:36:58 -07:00
parent 830252257b
commit 58e2feceb0
3 changed files with 96 additions and 44 deletions

View file

@ -1,4 +1,8 @@
import textwrap
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
@ -41,8 +45,21 @@ def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
sys_content += existing_system_message.content
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))