mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Support for Llama3.2 models and Swift SDK (#98)
This commit is contained in:
parent
95abbf576b
commit
56aed59eb4
56 changed files with 3745 additions and 630 deletions
|
@ -94,14 +94,16 @@ class AgentsClient(Agents):
|
|||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
||||
async def _run_agent(
|
||||
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
|
||||
):
|
||||
agent_config = AgentConfig(
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
model=model,
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
|
||||
sampling_params=SamplingParams(temperature=0.6, top_p=0.9),
|
||||
tools=tool_definitions,
|
||||
tool_choice=ToolChoice.auto,
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
|
@ -130,7 +132,8 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
|||
log.print()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
async def run_llama_3_1(host: str, port: int):
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
|
@ -167,10 +170,11 @@ async def run_main(host: str, port: int):
|
|||
"Write code to check if a number is prime. Use that to check if 7 is prime",
|
||||
"What is the boiling point of polyjuicepotion ?",
|
||||
]
|
||||
await _run_agent(api, tool_definitions, user_prompts)
|
||||
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
|
||||
|
||||
|
||||
async def run_rag(host: str, port: int):
|
||||
async def run_llama_3_2_rag(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
|
@ -206,12 +210,71 @@ async def run_rag(host: str, port: int):
|
|||
"Tell me briefly about llama3 and torchtune",
|
||||
]
|
||||
|
||||
await _run_agent(api, tool_definitions, user_prompts, attachments)
|
||||
await _run_agent(
|
||||
api, model, tool_definitions, ToolPromptFormat.json, user_prompts, attachments
|
||||
)
|
||||
|
||||
|
||||
def main(host: str, port: int, rag: bool = False):
|
||||
fn = run_rag if rag else run_main
|
||||
asyncio.run(fn(host, port))
|
||||
async def run_llama_3_2(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
# zero shot tools for llama3.2 text models
|
||||
tool_definitions = [
|
||||
FunctionCallToolDefinition(
|
||||
function_name="get_boiling_point",
|
||||
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
parameters={
|
||||
"liquid_name": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinition(
|
||||
param_type="bool",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
FunctionCallToolDefinition(
|
||||
function_name="make_web_search",
|
||||
description="Search the web / internet for more realtime information",
|
||||
parameters={
|
||||
"query": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="the query to search for",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
user_prompts = [
|
||||
"Who are you?",
|
||||
"what is the 100th prime number?",
|
||||
"Who was 44th President of USA?",
|
||||
# multiple tool calls in a single prompt
|
||||
"What is the boiling point of polyjuicepotion and pinkponklyjuice?",
|
||||
]
|
||||
await _run_agent(
|
||||
api, model, tool_definitions, ToolPromptFormat.python_list, user_prompts
|
||||
)
|
||||
|
||||
|
||||
def main(host: str, port: int, run_type: str):
|
||||
assert run_type in [
|
||||
"tools_llama_3_1",
|
||||
"tools_llama_3_2",
|
||||
"rag_llama_3_2",
|
||||
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
|
||||
|
||||
fn = {
|
||||
"tools_llama_3_1": run_llama_3_1,
|
||||
"tools_llama_3_2": run_llama_3_2,
|
||||
"rag_llama_3_2": run_llama_3_2_rag,
|
||||
}
|
||||
asyncio.run(fn[run_type](host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue