Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

View file

@ -94,14 +94,16 @@ class AgentsClient(Agents):
print(f"Error with parsing or validation: {e}")
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
):
agent_config = AgentConfig(
model="Meta-Llama3.1-8B-Instruct",
model=model,
instructions="You are a helpful assistant",
sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
tools=tool_definitions,
tool_choice=ToolChoice.auto,
tool_prompt_format=ToolPromptFormat.function_tag,
tool_prompt_format=tool_prompt_format,
enable_session_persistence=False,
)
@ -130,7 +132,8 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
log.print()
async def run_main(host: str, port: int):
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -167,10 +170,11 @@ async def run_main(host: str, port: int):
"Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion ?",
]
await _run_agent(api, tool_definitions, user_prompts)
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_rag(host: str, port: int):
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -206,12 +210,71 @@ async def run_rag(host: str, port: int):
"Tell me briefly about llama3 and torchtune",
]
await _run_agent(api, tool_definitions, user_prompts, attachments)
await _run_agent(
api, model, tool_definitions, ToolPromptFormat.json, user_prompts, attachments
)
def main(host: str, port: int, rag: bool = False):
fn = run_rag if rag else run_main
asyncio.run(fn(host, port))
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
tool_definitions = [
FunctionCallToolDefinition(
function_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="bool",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
FunctionCallToolDefinition(
function_name="make_web_search",
description="Search the web / internet for more realtime information",
parameters={
"query": ToolParamDefinition(
param_type="str",
description="the query to search for",
required=True,
),
},
),
]
user_prompts = [
"Who are you?",
"what is the 100th prime number?",
"Who was 44th President of USA?",
# multiple tool calls in a single prompt
"What is the boiling point of polyjuicepotion and pinkponklyjuice?",
]
await _run_agent(
api, model, tool_definitions, ToolPromptFormat.python_list, user_prompts
)
def main(host: str, port: int, run_type: str):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
"rag_llama_3_2",
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
fn = {
"tools_llama_3_1": run_llama_3_1,
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
if __name__ == "__main__":