Support for Llama3.2 models and Swift SDK (#98)

Ashwin Bharambe, 2024-09-25 10:29:58 -07:00 (committed by GitHub)
parent 95abbf576b
commit 56aed59eb4
56 changed files with 3745 additions and 630 deletions


@@ -94,14 +94,16 @@ class AgentsClient(Agents):
             print(f"Error with parsing or validation: {e}")


-async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
+async def _run_agent(
+    api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
+):
     agent_config = AgentConfig(
-        model="Meta-Llama3.1-8B-Instruct",
+        model=model,
         instructions="You are a helpful assistant",
-        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
         tools=tool_definitions,
         tool_choice=ToolChoice.auto,
-        tool_prompt_format=ToolPromptFormat.function_tag,
+        tool_prompt_format=tool_prompt_format,
         enable_session_persistence=False,
     )
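The new `tool_prompt_format` parameter is what lets one helper drive both model families: the hunks below pass `ToolPromptFormat.json` for the Llama 3.1 and RAG runs, and `ToolPromptFormat.python_list` for zero-shot tools on Llama 3.2 text models. A minimal sketch of the parameterized config (the import path is an assumption based on this client's star imports; values mirror this hunk):

    # Path assumed; the client pulls these types in via star imports.
    from llama_stack.apis.agents import *  # noqa: F403

    # Sketch only: the two knobs the refactor threads through _run_agent.
    agent_config = AgentConfig(
        model="Llama3.2-3B-Instruct",  # now caller-supplied instead of hardcoded
        instructions="You are a helpful assistant",
        sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
        tools=[],  # fill with FunctionCallToolDefinition entries
        tool_choice=ToolChoice.auto,
        # json for Llama 3.1; python_list for zero-shot tools on Llama 3.2 text models
        tool_prompt_format=ToolPromptFormat.python_list,
        enable_session_persistence=False,
    )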
@@ -130,7 +132,8 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
         log.print()


-async def run_main(host: str, port: int):
+async def run_llama_3_1(host: str, port: int):
+    model = "Llama3.1-8B-Instruct"
     api = AgentsClient(f"http://{host}:{port}")
     tool_definitions = [
@@ -167,10 +170,11 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
-    await _run_agent(api, tool_definitions, user_prompts)
+    await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_rag(host: str, port: int):
+async def run_llama_3_2_rag(host: str, port: int):
+    model = "Llama3.2-3B-Instruct"
     api = AgentsClient(f"http://{host}:{port}")
     urls = [
@@ -206,12 +210,71 @@ async def run_rag(host: str, port: int):
         "Tell me briefly about llama3 and torchtune",
     ]
-    await _run_agent(api, tool_definitions, user_prompts, attachments)
+    await _run_agent(
+        api, model, tool_definitions, ToolPromptFormat.json, user_prompts, attachments
+    )


-def main(host: str, port: int, rag: bool = False):
-    fn = run_rag if rag else run_main
-    asyncio.run(fn(host, port))
+async def run_llama_3_2(host: str, port: int):
+    model = "Llama3.2-3B-Instruct"
+    api = AgentsClient(f"http://{host}:{port}")
+
+    # zero-shot tools for Llama 3.2 text models
+    tool_definitions = [
+        FunctionCallToolDefinition(
+            function_name="get_boiling_point",
+            description="Get the boiling point of an imaginary liquid (eg. polyjuice)",
+            parameters={
+                "liquid_name": ToolParamDefinition(
+                    param_type="str",
+                    description="The name of the liquid",
+                    required=True,
+                ),
+                "celcius": ToolParamDefinition(
+                    param_type="bool",
+                    description="Whether to return the boiling point in Celsius",
+                    required=False,
+                ),
+            },
+        ),
+        FunctionCallToolDefinition(
+            function_name="make_web_search",
+            description="Search the web / internet for more realtime information",
+            parameters={
+                "query": ToolParamDefinition(
+                    param_type="str",
+                    description="the query to search for",
+                    required=True,
+                ),
+            },
+        ),
+    ]
+
+    user_prompts = [
+        "Who are you?",
+        "what is the 100th prime number?",
+        "Who was the 44th President of the USA?",
+        # multiple tool calls in a single prompt
+        "What is the boiling point of polyjuicepotion and pinkponklyjuice?",
+    ]
+    await _run_agent(
+        api, model, tool_definitions, ToolPromptFormat.python_list, user_prompts
+    )
+
+
+def main(host: str, port: int, run_type: str):
+    assert run_type in [
+        "tools_llama_3_1",
+        "tools_llama_3_2",
+        "rag_llama_3_2",
+    ], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
+
+    fn = {
+        "tools_llama_3_1": run_llama_3_1,
+        "tools_llama_3_2": run_llama_3_2,
+        "rag_llama_3_2": run_llama_3_2_rag,
+    }
+    asyncio.run(fn[run_type](host, port))


 if __name__ == "__main__":
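The boolean `rag` flag is replaced by a `run_type` string that dispatches to one of three scenarios via a dict, with the assert failing fast and listing the valid values. Assuming the file still ends with a `fire.Fire(main)` entry point (it sits outside this hunk, so that is an assumption), invocation would look like:

    # Hypothetical launches against a server on localhost:5000:
    #   python client.py localhost 5000 tools_llama_3_1   # Llama 3.1 tool calling
    #   python client.py localhost 5000 tools_llama_3_2   # Llama 3.2 zero-shot tools
    #   python client.py localhost 5000 rag_llama_3_2     # Llama 3.2 RAG with attachments
    main("localhost", 5000, "tools_llama_3_2")  # equivalent direct call; host/port are placeholders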


@@ -10,6 +10,10 @@
 from typing import Any, AsyncGenerator, List, Optional

 import fire
 import httpx
+
+from llama_models.llama3.api.datatypes import ImageMedia, URL
+from PIL import Image as PIL_Image
+from pydantic import BaseModel

 from llama_models.llama3.api import *  # noqa: F403
@@ -105,7 +109,7 @@ async def run_main(host: str, port: int, stream: bool):
     )
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
-        model="Meta-Llama3.1-8B-Instruct",
+        model="Llama3.1-8B-Instruct",
         messages=[message],
         stream=stream,
     )
@@ -113,8 +117,34 @@ async def run_main(host: str, port: int, stream: bool):
         log.print()


-def main(host: str, port: int, stream: bool = True):
-    asyncio.run(run_main(host, port, stream))
+async def run_mm_main(host: str, port: int, stream: bool, path: str):
+    client = InferenceClient(f"http://{host}:{port}")
+
+    with open(path, "rb") as f:
+        img = PIL_Image.open(f).convert("RGB")
+
+    message = UserMessage(
+        content=[
+            ImageMedia(image=URL(uri=f"file://{path}")),
+            # ImageMedia(image=img),
+            "Describe this image in two sentences",
+        ],
+    )
+    cprint(f"User>{message.content}", "green")
+    iterator = client.chat_completion(
+        model="Llama3.2-11B-Vision-Instruct",
+        messages=[message],
+        stream=stream,
+    )
+    async for log in EventLogger().log(iterator):
+        log.print()
+
+
+def main(host: str, port: int, stream: bool = True, mm: bool = False, file: str = None):
+    if mm:
+        asyncio.run(run_mm_main(host, port, stream, file))
+    else:
+        asyncio.run(run_main(host, port, stream))


 if __name__ == "__main__":
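The new `mm` path sends the image as a `file://` URI (the in-memory `PIL_Image` variant stays commented out in the hunk). A hypothetical invocation, assuming the same `fire.Fire(main)` entry point and a served Llama3.2-11B-Vision-Instruct model; the image path is a placeholder:

    import asyncio

    # Equivalent to: python client.py localhost 5000 --mm True --file /tmp/cat.jpg
    asyncio.run(run_mm_main("localhost", 5000, stream=True, path="/tmp/cat.jpg"))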


@@ -7,11 +7,11 @@
 from typing import List, Optional, Protocol

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field

 from llama_stack.apis.memory import MemoryBankType
 from llama_stack.distribution.datatypes import GenericProviderConfig
+from pydantic import BaseModel, Field


 @json_schema_type


@@ -51,6 +51,11 @@ class SafetyClient(Safety):
             ),
             headers={
                 "Content-Type": "application/json",
+                "X-LlamaStack-ProviderData": json.dumps(
+                    {
+                        "together_api_key": "1882f9a484fc7c6ce3e4dc90272d5db52346c93838daab3d704803181f396b22"
+                    }
+                ),
             },
             timeout=20,
         )
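The `X-LlamaStack-ProviderData` header carries per-request provider credentials as a JSON blob, so a hosted provider (here Together) can be keyed by the caller rather than by server config. A minimal sketch of attaching it to a request (the header name and payload shape come from this hunk; the endpoint, body, and key are placeholders):

    import json

    import httpx

    # Placeholder credential; real keys should come from the environment, never source control.
    provider_data = {"together_api_key": "<YOUR_TOGETHER_API_KEY>"}

    response = httpx.post(
        "http://localhost:5000/safety/run_shield",  # placeholder endpoint
        content=json.dumps({"messages": []}),       # placeholder body
        headers={
            "Content-Type": "application/json",
            "X-LlamaStack-ProviderData": json.dumps(provider_data),
        },
        timeout=20,
    )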