Merge remote-tracking branch 'origin/main' into api_updates_1

Ashwin Bharambe 2024-08-28 16:02:34 -07:00
commit d3965dd435
11 changed files with 428 additions and 3 deletions

View file

@@ -1,3 +1,4 @@
include requirements.txt
include llama_toolchain/data/*.yaml
include llama_toolchain/distribution/*.sh
include llama_toolchain/cli/scripts/*.sh

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,38 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail

if [ $# -eq 0 ]; then
  echo "Please provide a URL as an argument."
  exit 1
fi

URL=$1
HEADERS_FILE=$(mktemp)
curl -s -I "$URL" >"$HEADERS_FILE"

# Recover the original wheel filename from the x-manifold-obj-canonicalpath header.
FILENAME=$(grep -i "x-manifold-obj-canonicalpath:" "$HEADERS_FILE" | sed -E 's/.*nodes\/[^\/]+\/(.+)/\1/' | tr -d "\r\n")

if [ -z "$FILENAME" ]; then
  echo "Could not find the x-manifold-obj-canonicalpath header."
  echo "HEADERS_FILE contents: "
  cat "$HEADERS_FILE"
  echo ""
  exit 1
fi

echo "Downloading $FILENAME..."
curl -s -L -o "$FILENAME" "$URL"

echo "Installing $FILENAME..."
pip install "$FILENAME"
echo "Successfully installed $FILENAME"

rm -f "$FILENAME"

View file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import subprocess
import sys


def install_wheel_from_presigned():
    # Delegate to the bundled shell script, forwarding any CLI arguments.
    file = "install-wheel-from-presigned.sh"
    script_path = os.path.join(os.path.dirname(__file__), file)
    try:
        subprocess.run(["sh", script_path] + sys.argv[1:], check=True)
    except Exception:
        sys.exit(1)

View file

@@ -38,6 +38,15 @@ def available_distribution_specs() -> List[DistributionSpec]:
                Api.memory: "meta-reference-faiss",
            },
        ),
        DistributionSpec(
            spec_id="remote-fireworks",
            description="Use Fireworks.ai for running LLM inference",
            provider_specs={
                Api.inference: providers[Api.inference]["fireworks"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
            },
        ),
    ]
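
For context, a small lookup sketch (not part of this diff); the import path for `available_distribution_specs` is assumed here and may differ from the actual module layout:

```python
# Illustrative only: the import path below is an assumption, not taken from this
# commit; point it at wherever available_distribution_specs() actually lives.
from llama_toolchain.distribution.registry import available_distribution_specs

specs = {spec.spec_id: spec for spec in available_distribution_specs()}
fireworks_spec = specs["remote-fireworks"]
print(fireworks_spec.description)
print(list(fireworks_spec.provider_specs.keys()))  # Api.inference, Api.safety, Api.agentic_system
```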

View file

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FireworksImplConfig # noqa
from .fireworks import get_provider_impl # noqa

View file

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class FireworksImplConfig(BaseModel):
    url: str = Field(
        default="https://api.fireworks.ai/inference",
        description="The URL for the Fireworks server",
    )
    api_key: str = Field(
        default="",
        description="The Fireworks.ai API Key",
    )
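
As an aside (not part of the diff), a minimal sketch of constructing this config; reading the key from a `FIREWORKS_API_KEY` environment variable is an illustrative choice, not something this commit defines:

```python
# Sketch only, under the assumption stated above.
import os

from llama_toolchain.inference.fireworks import FireworksImplConfig

config = FireworksImplConfig(
    url="https://api.fireworks.ai/inference",
    api_key=os.environ.get("FIREWORKS_API_KEY", ""),
)
print(config.url, "api key set:", bool(config.api_key))
```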

View file

@@ -0,0 +1,312 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import AsyncGenerator, Dict
import httpx
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
Message,
StopReason,
ToolCall,
)
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from fireworks.client import Fireworks
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
    "Meta-Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
    "Meta-Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
    "Meta-Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
}


async def get_provider_impl(
    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
    assert isinstance(
        config, FireworksImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = FireworksInference(config)
    await impl.initialize()
    return impl


class FireworksInference(Inference):
    def __init__(self, config: FireworksImplConfig) -> None:
        self.config = config

    @property
    def client(self) -> Fireworks:
        return Fireworks(api_key=self.config.api_key)

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
        fireworks_messages = []
        for message in messages:
            if message.role == "ipython":
                role = "tool"
            else:
                role = message.role
            fireworks_messages.append({"role": role, "content": message.content})

        return fireworks_messages

    def resolve_fireworks_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None
            and model.descriptor(shorten_default_variant=True)
            in FIREWORKS_SUPPORTED_MODELS
        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(FIREWORKS_SUPPORTED_MODELS.keys())}"

        return FIREWORKS_SUPPORTED_MODELS.get(
            model.descriptor(shorten_default_variant=True)
        )

    def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        if request.sampling_params is not None:
            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
                if getattr(request.sampling_params, attr):
                    options[attr] = getattr(request.sampling_params, attr)

        return options
    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        # accumulate sampling params and other options to pass to fireworks
        options = self.get_fireworks_chat_options(request)
        fireworks_model = self.resolve_fireworks_model(request.model)

        if not request.stream:
            r = await self.client.chat.completions.acreate(
                model=fireworks_model,
                messages=self._messages_to_fireworks_messages(request.messages),
                stream=False,
                **options,
            )
            stop_reason = None
            if r.choices[0].finish_reason:
                if r.choices[0].finish_reason == "stop":
                    stop_reason = StopReason.end_of_turn
                elif r.choices[0].finish_reason == "length":
                    stop_reason = StopReason.out_of_tokens

            completion_message = decode_assistant_message_from_content(
                r.choices[0].message.content,
                stop_reason,
            )
            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )

            buffer = ""
            ipython = False
            stop_reason = None

            async for chunk in self.client.chat.completions.acreate(
                model=fireworks_model,
                messages=self._messages_to_fireworks_messages(request.messages),
                stream=True,
                **options,
            ):
                if chunk.choices[0].finish_reason:
                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
                        stop_reason = StopReason.end_of_turn
                    elif (
                        stop_reason is None
                        and chunk.choices[0].finish_reason == "length"
                    ):
                        stop_reason = StopReason.out_of_tokens
                    break

                text = chunk.choices[0].delta.content
                if text is None:
                    continue

                # check if it's a tool call (aka starts with <|python_tag|>)
                if not ipython and text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    buffer += text
                    continue

                if ipython:
                    if text == "<|eot_id|>":
                        stop_reason = StopReason.end_of_turn
                        text = ""
                        continue
                    elif text == "<|eom_id|>":
                        stop_reason = StopReason.end_of_message
                        text = ""
                        continue

                    buffer += text
                    delta = ToolCallDelta(
                        content=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                        )
                    )
                else:
                    buffer += text
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=text,
                            stop_reason=stop_reason,
                        )
                    )

            # parse tool calls and report errors
            message = decode_assistant_message_from_content(buffer, stop_reason)
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )
# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
    content: str,
    stop_reason: StopReason,
) -> CompletionMessage:
    ipython = content.startswith("<|python_tag|>")
    if ipython:
        content = content[len("<|python_tag|>") :]

    if content.endswith("<|eot_id|>"):
        content = content[: -len("<|eot_id|>")]
        stop_reason = StopReason.end_of_turn
    elif content.endswith("<|eom_id|>"):
        content = content[: -len("<|eom_id|>")]
        stop_reason = StopReason.end_of_message

    tool_name = None
    tool_arguments = {}

    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
    if custom_tool_info is not None:
        tool_name, tool_arguments = custom_tool_info
        # Sometimes when the agent has custom tools alongside builtin tools,
        # the agent responds to builtin tool calls in the format of the custom tools.
        # This code tries to handle that case.
        if tool_name in BuiltinTool.__members__:
            tool_name = BuiltinTool[tool_name]
            tool_arguments = {
                "query": list(tool_arguments.values())[0],
            }
    else:
        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
        if builtin_tool_info is not None:
            tool_name, query = builtin_tool_info
            tool_arguments = {
                "query": query,
            }
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
        elif ipython:
            tool_name = BuiltinTool.code_interpreter
            tool_arguments = {
                "code": content,
            }

    tool_calls = []
    if tool_name is not None and tool_arguments is not None:
        call_id = str(uuid.uuid4())
        tool_calls.append(
            ToolCall(
                call_id=call_id,
                tool_name=tool_name,
                arguments=tool_arguments,
            )
        )
        content = ""

    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    return CompletionMessage(
        content=content,
        stop_reason=stop_reason,
        tool_calls=tool_calls,
    )
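
To make the control flow above concrete, a hedged usage sketch (not from this commit): it assumes `ChatCompletionRequest` accepts the `model`, `messages`, and `stream` fields read above, and that `UserMessage` is exported by `llama_models.llama3.api.datatypes`.

```python
# Usage sketch under the assumptions stated above.
import asyncio

from llama_models.llama3.api.datatypes import UserMessage  # assumed export
from llama_toolchain.inference.api import ChatCompletionRequest
from llama_toolchain.inference.fireworks import FireworksImplConfig, get_provider_impl


async def main() -> None:
    impl = await get_provider_impl(FireworksImplConfig(api_key="<your-key>"), {})
    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Who won the 2023 NBA finals?")],
        stream=True,
    )
    # chat_completion() is an async generator yielding start/progress/complete chunks.
    async for chunk in impl.chat_completion(request):
        print(chunk.event.event_type, chunk.event.delta)


asyncio.run(main())
```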

View file

@@ -35,4 +35,13 @@ def available_inference_providers() -> List[ProviderSpec]:
                module="llama_toolchain.inference.adapters.ollama",
            ),
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_id="fireworks",
            pip_packages=[
                "fireworks-ai",
            ],
            module="llama_toolchain.inference.fireworks",
            config_class="llama_toolchain.inference.fireworks.FireworksImplConfig",
        ),
    ]

View file

@@ -47,7 +47,7 @@ Note that as of today, in the OSS world, such a “loop” is often coded explicitly
1. The model reasons once again (using all the messages above) and decides to send a final response "In 2023, Denver Nuggets played against the Miami Heat in the NBA finals." to the executor
1. The executor returns the response directly to the user (since there is no tool call to be executed.)
The sequence diagram that details the steps is here.
The sequence diagram that details the steps is [here](https://github.com/meta-llama/llama-agentic-system/blob/main/docs/sequence-diagram.md).
* /memory_banks - to support creating multiple repositories of data that can be available for agentic systems
* /agentic_system - to support creating and running agentic systems. The sub-APIs support the creation and management of the steps, turns, and sessions within agentic applications.

View file

@@ -16,11 +16,16 @@ def read_requirements():
setup(
    name="llama_toolchain",
    version="0.0.8",
    version="0.0.10",
    author="Meta Llama",
    author_email="llama-oss@meta.com",
    description="Llama toolchain",
    entry_points={"console_scripts": ["llama = llama_toolchain.cli.llama:main"]},
    entry_points={
        "console_scripts": [
            "llama = llama_toolchain.cli.llama:main",
            "install-wheel-from-presigned = llama_toolchain.cli.scripts.run:install_wheel_from_presigned",
        ]
    },
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/meta-llama/llama-toolchain",