mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 09:01:59 +00:00
Merge branch 'main' into test-modelregistryhelper
This commit is contained in:
commit
7fd8a61b4d
80 changed files with 2918 additions and 386 deletions
|
|
@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
|
||||
image_type = prompt(
|
||||
f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
|
||||
"> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
|
||||
completer=WordCompleter([e.value for e in ImageType]),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in [e.value for e in ImageType],
|
||||
error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
|
||||
error_message="Invalid image type. Use <TAB> to see options",
|
||||
),
|
||||
default=ImageType.CONDA.value,
|
||||
)
|
||||
|
||||
if image_type == ImageType.CONDA.value:
|
||||
|
|
@ -317,11 +318,15 @@ def _generate_run_config(
|
|||
to_write = json.loads(run_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
# this path is only invoked when no template is provided
|
||||
cprint(
|
||||
f"You can now run your stack with `llama stack run {run_config_file}`",
|
||||
color="green",
|
||||
)
|
||||
# Only print this message for non-container builds since it will be displayed before the
|
||||
# container is built
|
||||
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
||||
# makes sense to display this message
|
||||
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
||||
cprint(
|
||||
f"You can now run your stack with `llama stack run {run_config_file}`",
|
||||
color="green",
|
||||
)
|
||||
return run_config_file
|
||||
|
||||
|
||||
|
|
@ -355,6 +360,13 @@ def _run_stack_build_command_from_build_config(
|
|||
build_file_path = build_dir / f"{image_name}-build.yaml"
|
||||
|
||||
os.makedirs(build_dir, exist_ok=True)
|
||||
run_config_file = None
|
||||
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||
# Only do this if we're building a container image and we're not using a template
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||
cprint("Generating run.yaml file", color="green")
|
||||
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||
|
||||
with open(build_file_path, "w") as f:
|
||||
to_write = json.loads(build_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
|
@ -364,6 +376,7 @@ def _run_stack_build_command_from_build_config(
|
|||
build_file_path,
|
||||
image_name,
|
||||
template_or_config=template_name or config_path or str(build_file_path),
|
||||
run_config=run_config_file,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"Failed to build image {image_name}")
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ def build_image(
|
|||
build_file_path: Path,
|
||||
image_name: str,
|
||||
template_or_config: str,
|
||||
run_config: str | None = None,
|
||||
):
|
||||
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
|
||||
|
||||
|
|
@ -108,6 +109,11 @@ def build_image(
|
|||
container_base,
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
# When building from a config file (not a template), include the run config path in the
|
||||
# build arguments
|
||||
if run_config is not None:
|
||||
args.append(run_config)
|
||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||
args = [
|
||||
|
|
|
|||
|
|
@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
|||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
BUILD_CONTEXT_DIR=$(pwd)
|
||||
|
||||
if [ "$#" -lt 4 ]; then
|
||||
# This only works for templates
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
template_or_config="$1"
|
||||
|
|
@ -35,8 +39,27 @@ container_base="$1"
|
|||
shift
|
||||
pip_dependencies="$1"
|
||||
shift
|
||||
special_pip_deps="${1:-}"
|
||||
|
||||
# Handle optional arguments
|
||||
run_config=""
|
||||
special_pip_deps=""
|
||||
|
||||
# Check if there are more arguments
|
||||
# The logics is becoming cumbersom, we should refactor it if we can do better
|
||||
if [ $# -gt 0 ]; then
|
||||
# Check if the argument ends with .yaml
|
||||
if [[ "$1" == *.yaml ]]; then
|
||||
run_config="$1"
|
||||
shift
|
||||
# If there's another argument after .yaml, it must be special_pip_deps
|
||||
if [ $# -gt 0 ]; then
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
else
|
||||
# If it's not .yaml, it must be special_pip_deps
|
||||
special_pip_deps="$1"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
|
|
@ -75,7 +98,7 @@ WORKDIR /app
|
|||
# We install the Python 3.11 dev headers and build tools so that any
|
||||
# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.
|
||||
|
||||
RUN dnf -y update && dnf install -y iputils net-tools wget \
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
||||
python3.11-setuptools python3.11-devel gcc make && \
|
||||
ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
|
||||
|
|
@ -119,6 +142,45 @@ EOF
|
|||
done
|
||||
fi
|
||||
|
||||
# Function to get Python command
|
||||
get_python_cmd() {
|
||||
if is_command_available python; then
|
||||
echo "python"
|
||||
elif is_command_available python3; then
|
||||
echo "python3"
|
||||
else
|
||||
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
if [ -n "$run_config" ]; then
|
||||
# Copy the run config to the build context since it's an absolute path
|
||||
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
add_to_container << EOF
|
||||
COPY run.yaml $RUN_CONFIG_PATH
|
||||
EOF
|
||||
|
||||
# Parse the run.yaml configuration to identify external provider directories
|
||||
# If external providers are specified, copy their directory to the container
|
||||
# and update the configuration to reference the new container path
|
||||
python_cmd=$(get_python_cmd)
|
||||
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
|
||||
if [ -n "$external_providers_dir" ]; then
|
||||
echo "Copying external providers directory: $external_providers_dir"
|
||||
add_to_container << EOF
|
||||
COPY $external_providers_dir /app/providers.d
|
||||
EOF
|
||||
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
|
||||
else
|
||||
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
client_mount="/app/llama-stack-client-source"
|
||||
|
||||
|
|
@ -178,15 +240,16 @@ fi
|
|||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
||||
if [[ "$template_or_config" != *.yaml ]]; then
|
||||
# If a run config is provided, we use the --config flag
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
# If a template is provided (not a yaml file), we use the --template flag
|
||||
elif [[ "$template_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
||||
EOF
|
||||
else
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Add other require item commands genearic to all containers
|
||||
|
|
@ -258,9 +321,10 @@ $CONTAINER_BINARY build \
|
|||
"${CLI_ARGS[@]}" \
|
||||
-t "$image_tag" \
|
||||
-f "$TEMP_DIR/Containerfile" \
|
||||
"."
|
||||
"$BUILD_CONTEXT_DIR"
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
set +x
|
||||
|
||||
echo "Success!"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ import asyncio
|
|||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
|
@ -526,7 +531,7 @@ class InferenceRouter(Inference):
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
|
|
@ -558,6 +563,16 @@ class InferenceRouter(Inference):
|
|||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
if tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||
if tools is None:
|
||||
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
|
|||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
|
@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError):
|
||||
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, TimeoutError):
|
||||
|
|
@ -162,14 +165,17 @@ async def maybe_await(value):
|
|||
return value
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
async def sse_generator(event_gen_coroutine):
|
||||
event_gen = None
|
||||
try:
|
||||
async for item in await event_gen:
|
||||
event_gen = await event_gen_coroutine
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
if event_gen:
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
|
|
@ -455,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
|
|
|
|||
|
|
@ -24,6 +24,13 @@ def rag_chat_page():
|
|||
def should_disable_input():
|
||||
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||
|
||||
def log_message(message):
|
||||
with st.chat_message(message["role"]):
|
||||
if "tool_output" in message and message["tool_output"]:
|
||||
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
|
||||
st.write(message["tool_output"])
|
||||
st.markdown(message["content"])
|
||||
|
||||
with st.sidebar:
|
||||
# File/Directory Upload Section
|
||||
st.subheader("Upload Documents", divider=True)
|
||||
|
|
@ -146,8 +153,7 @@ def rag_chat_page():
|
|||
|
||||
# Display chat history
|
||||
for message in st.session_state.displayed_messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
log_message(message)
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
|
|
@ -201,7 +207,7 @@ def rag_chat_page():
|
|||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
retrieval_message_placeholder = st.empty()
|
||||
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
|
|
@ -209,14 +215,16 @@ def rag_chat_page():
|
|||
log.print()
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
retrieval_message_placeholder.write(retrieval_response)
|
||||
else:
|
||||
full_response += log.content
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
|
||||
st.session_state.displayed_messages.append(
|
||||
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
|
||||
)
|
||||
|
||||
def direct_process_prompt(prompt):
|
||||
# Add the system prompt in the beginning of the conversation
|
||||
|
|
@ -230,15 +238,14 @@ def rag_chat_page():
|
|||
prompt_context = rag_response.content
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
with st.expander(label="Retrieval Output", expanded=False):
|
||||
st.write(prompt_context)
|
||||
|
||||
retrieval_message_placeholder = st.empty()
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
|
||||
# Display the retrieved content
|
||||
retrieval_response += str(prompt_context)
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
|
||||
# Construct the extended prompt
|
||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||
|
||||
|
|
|
|||
|
|
@ -4,14 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import enum
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
from llama_stack_client.lib.agents.react.agent import ReActAgent
|
||||
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
class AgentType(enum.Enum):
|
||||
REGULAR = "Regular"
|
||||
REACT = "ReAct"
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
|
|
@ -23,6 +32,7 @@ def tool_chat_page():
|
|||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
selected_vector_dbs = []
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
|
|
@ -66,25 +76,36 @@ def tool_chat_page():
|
|||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
||||
active_tool_list = []
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
active_tool_list.extend(
|
||||
[
|
||||
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
|
||||
for t in client.tools.list(toolgroup_id=toolgroup_id)
|
||||
]
|
||||
)
|
||||
grouped_tools = {}
|
||||
total_tools = 0
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
|
||||
st.json(active_tool_list)
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
||||
for group_id, tools in grouped_tools.items():
|
||||
with st.expander(f"🔧 Tools from `{group_id}`"):
|
||||
for idx, tool in enumerate(tools, start=1):
|
||||
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
|
||||
|
||||
st.subheader("Agent Configurations")
|
||||
st.subheader("Agent Type")
|
||||
agent_type = st.radio(
|
||||
"Select Agent Type",
|
||||
[AgentType.REGULAR, AgentType.REACT],
|
||||
format_func=lambda x: x.value,
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
step=64,
|
||||
help="The maximum number of tokens to generate",
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
|
@ -101,13 +122,27 @@ def tool_chat_page():
|
|||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
|
||||
return ReActAgent(
|
||||
client=client,
|
||||
model=model,
|
||||
tools=toolgroup_selection,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": ReActOutput.model_json_schema(),
|
||||
},
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
else:
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
st.session_state.agent_type = agent_type
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
|
|
@ -136,6 +171,158 @@ def tool_chat_page():
|
|||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
if st.session_state.get("agent_type") == AgentType.REACT:
|
||||
return _handle_react_response(turn_response)
|
||||
else:
|
||||
return _handle_regular_response(turn_response)
|
||||
|
||||
def _handle_react_response(turn_response):
|
||||
current_step_content = ""
|
||||
final_answer = None
|
||||
tool_results = []
|
||||
|
||||
for response in turn_response:
|
||||
if not hasattr(response.event, "payload"):
|
||||
yield (
|
||||
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
|
||||
"The response received is missing an expected `payload` attribute.\n"
|
||||
"This could indicate a malformed response or an internal issue within the server.\n\n"
|
||||
f"Error details: {response}"
|
||||
)
|
||||
return
|
||||
|
||||
payload = response.event.payload
|
||||
|
||||
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
|
||||
current_step_content += payload.delta.text
|
||||
continue
|
||||
|
||||
if payload.event_type == "step_complete":
|
||||
step_details = payload.step_details
|
||||
|
||||
if step_details.step_type == "inference":
|
||||
yield from _process_inference_step(current_step_content, tool_results, final_answer)
|
||||
current_step_content = ""
|
||||
elif step_details.step_type == "tool_execution":
|
||||
tool_results = _process_tool_execution(step_details, tool_results)
|
||||
current_step_content = ""
|
||||
else:
|
||||
current_step_content = ""
|
||||
|
||||
if not final_answer and tool_results:
|
||||
yield from _format_tool_results_summary(tool_results)
|
||||
|
||||
def _process_inference_step(current_step_content, tool_results, final_answer):
|
||||
try:
|
||||
react_output_data = json.loads(current_step_content)
|
||||
thought = react_output_data.get("thought")
|
||||
action = react_output_data.get("action")
|
||||
answer = react_output_data.get("answer")
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
final_answer = answer
|
||||
|
||||
if thought:
|
||||
with st.expander("🤔 Thinking...", expanded=False):
|
||||
st.markdown(f":grey[__{thought}__]")
|
||||
|
||||
if action and isinstance(action, dict):
|
||||
tool_name = action.get("tool_name")
|
||||
tool_params = action.get("tool_params")
|
||||
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
|
||||
st.json(tool_params)
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
yield f"\n\n✅ **Final Answer:**\n{answer}"
|
||||
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
|
||||
except Exception as e:
|
||||
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
|
||||
|
||||
return final_answer
|
||||
|
||||
def _process_tool_execution(step_details, tool_results):
|
||||
try:
|
||||
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
|
||||
for tool_response in step_details.tool_responses:
|
||||
tool_name = tool_response.tool_name
|
||||
content = tool_response.content
|
||||
tool_results.append((tool_name, content))
|
||||
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
st.json(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
st.code(content, language=None)
|
||||
else:
|
||||
with st.expander("⚙️ Observation", expanded=False):
|
||||
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
|
||||
except Exception as e:
|
||||
with st.expander("⚙️ Error in Tool Execution", expanded=False):
|
||||
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
|
||||
|
||||
return tool_results
|
||||
|
||||
def _format_tool_results_summary(tool_results):
|
||||
yield "\n\n**Here's what I found:**\n"
|
||||
for tool_name, content in tool_results:
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
|
||||
if tool_name == "web_search" and "top_k" in parsed_content:
|
||||
yield from _format_web_search_results(parsed_content)
|
||||
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
|
||||
yield from _format_results_list(parsed_content["results"])
|
||||
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
|
||||
yield from _format_dict_results(parsed_content)
|
||||
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
|
||||
yield from _format_list_results(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
|
||||
except (TypeError, AttributeError, KeyError, IndexError) as e:
|
||||
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
|
||||
|
||||
def _format_web_search_results(parsed_content):
|
||||
for i, result in enumerate(parsed_content["top_k"], 1):
|
||||
if i <= 3:
|
||||
title = result.get("title", "Untitled")
|
||||
url = result.get("url", "")
|
||||
content_text = result.get("content", "").strip()
|
||||
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
|
||||
|
||||
def _format_results_list(results):
|
||||
for i, result in enumerate(results, 1):
|
||||
if i <= 3:
|
||||
if isinstance(result, dict):
|
||||
name = result.get("name", result.get("title", "Result " + str(i)))
|
||||
description = result.get("description", result.get("content", result.get("summary", "")))
|
||||
yield f"\n- **{name}**\n {description}\n"
|
||||
else:
|
||||
yield f"\n- {result}\n"
|
||||
|
||||
def _format_dict_results(parsed_content):
|
||||
yield "\n```\n"
|
||||
for key, value in list(parsed_content.items())[:5]:
|
||||
if isinstance(value, str) and len(value) < 100:
|
||||
yield f"{key}: {value}\n"
|
||||
else:
|
||||
yield f"{key}: [Complex data]\n"
|
||||
yield "```\n"
|
||||
|
||||
def _format_list_results(parsed_content):
|
||||
yield "\n"
|
||||
for _, item in enumerate(parsed_content[:3], 1):
|
||||
if isinstance(item, str):
|
||||
yield f"- {item}\n"
|
||||
elif isinstance(item, dict) and "text" in item:
|
||||
yield f"- {item['text']}\n"
|
||||
elif isinstance(item, dict) and len(item) > 0:
|
||||
first_value = next(iter(item.values()))
|
||||
if isinstance(first_value, str) and len(first_value) < 100:
|
||||
yield f"- {first_value}\n"
|
||||
|
||||
def _handle_regular_response(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
|
|
@ -144,14 +331,18 @@ def tool_chat_page():
|
|||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
yield " 🛠 "
|
||||
if response.event.payload.step_details.tool_calls:
|
||||
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
|
||||
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
|
||||
else:
|
||||
yield "No tool_calls present in step_details"
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response = st.write_stream(response_generator(turn_response))
|
||||
response_content = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
st.session_state.messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
tool_chat_page()
|
||||
|
|
|
|||
|
|
@ -303,6 +303,7 @@ class ChatFormat:
|
|||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til
|
|||
|
||||
##### Model Response Format
|
||||
```
|
||||
The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|>
|
||||
The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|>
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model
|
|||
|
||||
##### Model Response Format
|
||||
```
|
||||
This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|>
|
||||
The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|>
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model
|
|||
|
||||
##### Model Response Format
|
||||
```
|
||||
The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|>
|
||||
The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|>
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver
|
|||
```
|
||||
<|begin_of_text|><|header_start|>system<|header_end|>
|
||||
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||
|
||||
1. FUNCTION CALLS:
|
||||
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||
Examples:
|
||||
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||
INCORRECT: get_weather(location="New York")
|
||||
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||
|
||||
2. RESPONSE RULES:
|
||||
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||
- For knowledge questions: ONLY output text
|
||||
- For missing parameters: ONLY request the specific missing parameters
|
||||
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||
- NEVER combine text and function calls in the same response
|
||||
- NEVER suggest alternative functions when the requested service is unavailable
|
||||
- NEVER create or invent new functions not listed below
|
||||
|
||||
3. STRICT BOUNDARIES:
|
||||
- ONLY use functions from the list below - no exceptions
|
||||
- NEVER use a function as an alternative to unavailable information
|
||||
- NEVER call functions not present in the function list
|
||||
- NEVER add explanatory text to function calls
|
||||
- NEVER respond with empty brackets
|
||||
- Use proper Python/JSON syntax for function calls
|
||||
- Check the function list carefully before responding
|
||||
|
||||
4. TOOL RESPONSE HANDLING:
|
||||
- When receiving tool responses: provide concise, natural language responses
|
||||
- Don't repeat tool response verbatim
|
||||
- Don't add supplementary information
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
|
|
@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke.
|
|||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"city"
|
||||
],
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
|
|
@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke.
|
|||
}
|
||||
}
|
||||
}
|
||||
<|eot|><|header_start|>user<|header_end|>
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
|
||||
|
||||
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
|
||||
|
||||
|
|
@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e
|
|||
|
||||
##### Model Response Format
|
||||
```
|
||||
[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|>
|
||||
[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|>
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e
|
|||
|
||||
##### Model Response Format
|
||||
```
|
||||
<function=trending_songs>{"n": "10"}</function><|eot|>
|
||||
<function=trending_songs>{"n": 10}</function><|eot|>
|
||||
```
|
||||
|
|
|
|||
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
PromptTemplate,
|
||||
PromptTemplateGeneratorBase,
|
||||
)
|
||||
|
||||
|
||||
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||
DEFAULT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||
|
||||
1. FUNCTION CALLS:
|
||||
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||
Examples:
|
||||
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||
INCORRECT: get_weather(location="New York")
|
||||
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||
|
||||
2. RESPONSE RULES:
|
||||
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||
- For knowledge questions: ONLY output text
|
||||
- For missing parameters: ONLY request the specific missing parameters
|
||||
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||
- NEVER combine text and function calls in the same response
|
||||
- NEVER suggest alternative functions when the requested service is unavailable
|
||||
- NEVER create or invent new functions not listed below
|
||||
|
||||
3. STRICT BOUNDARIES:
|
||||
- ONLY use functions from the list below - no exceptions
|
||||
- NEVER use a function as an alternative to unavailable information
|
||||
- NEVER call functions not present in the function list
|
||||
- NEVER add explanatory text to function calls
|
||||
- NEVER respond with empty brackets
|
||||
- Use proper Python/JSON syntax for function calls
|
||||
- Check the function list carefully before responding
|
||||
|
||||
4. TOOL RESPONSE HANDLING:
|
||||
- When receiving tool responses: provide concise, natural language responses
|
||||
- Don't repeat tool response verbatim
|
||||
- Don't add supplementary information
|
||||
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
)
|
||||
|
||||
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||
return PromptTemplate(
|
||||
system_prompt,
|
||||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{% for t in tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
}
|
||||
}{% if not loop.last %},
|
||||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template_str.strip("\n"),
|
||||
{"tools": [t.model_dump() for t in custom_tools]},
|
||||
).render()
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
]
|
||||
|
|
@ -9,6 +9,10 @@ from io import BytesIO
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator,
|
||||
)
|
||||
|
||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from ..prompt_format import (
|
||||
Llama4UseCase,
|
||||
|
|
@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]:
|
|||
[
|
||||
RawMessage(
|
||||
role="system",
|
||||
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"city"
|
||||
],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
content=PythonListCustomToolGenerator()
|
||||
.gen(PythonListCustomToolGenerator().data_examples()[0])
|
||||
.render(),
|
||||
),
|
||||
RawMessage(
|
||||
role="user",
|
||||
|
|
|
|||
|
|
@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl(
|
|||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.completion(request):
|
||||
for token_results in self.generator.completion([request]):
|
||||
token_result = token_results[0]
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
|
|
|
|||
|
|
@ -69,7 +69,10 @@ class CancelSentinel(BaseModel):
|
|||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
|
||||
task: Tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
|
@ -231,10 +234,10 @@ def worker_process_entrypoint(
|
|||
while True:
|
||||
try:
|
||||
task = req_gen.send(result)
|
||||
if isinstance(task, str) and task == EndSentinel():
|
||||
if isinstance(task, EndSentinel):
|
||||
break
|
||||
|
||||
assert isinstance(task, TaskRequest)
|
||||
assert isinstance(task, TaskRequest), task
|
||||
result = model(task.task)
|
||||
except StopIteration:
|
||||
break
|
||||
|
|
@ -331,7 +334,10 @@ class ModelParallelProcessGroup:
|
|||
|
||||
def run_inference(
|
||||
self,
|
||||
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
|
||||
req: Tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
|
|
@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
)
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return RAGQueryResult(
|
||||
content=picked,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
|
@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.agents,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.eval,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
),
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
Api.scoring,
|
||||
Api.inference,
|
||||
Api.agents,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watson_machine_learning"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
5
llama_stack/providers/remote/eval/__init__.py
Normal file
5
llama_stack/providers/remote/eval/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
# NVIDIA NeMo Evaluator Eval Provider
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
|
||||
|
||||
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
|
||||
|
||||
### Example for register an academic benchmark
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "mmlu",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "mmlu"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for register a custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "custom",
|
||||
"params": {
|
||||
"parallelism": 8
|
||||
},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {
|
||||
"template": {
|
||||
"prompt": "{{prompt}}",
|
||||
"max_tokens": 200
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
|
||||
},
|
||||
"metrics": {
|
||||
"bleu": {
|
||||
"type": "bleu",
|
||||
"params": {
|
||||
"references": [
|
||||
"{{ideal_response}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for triggering a benchmark/custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"benchmark_config": {
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama3.1-8B-Instruct",
|
||||
"sampling_params": {
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.7
|
||||
}
|
||||
},
|
||||
"scoring_params": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for getting the status of a job
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for cancelling a job
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
|
||||
```
|
||||
|
||||
### Example for getting the results
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/results
|
||||
```
|
||||
```json
|
||||
{
|
||||
"generations": [],
|
||||
"scores": {
|
||||
"{benchmark_id}": {
|
||||
"score_rows": [],
|
||||
"aggregated_results": {
|
||||
"tasks": {},
|
||||
"groups": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NVIDIAEvalConfig,
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import NVIDIAEvalImpl
|
||||
|
||||
impl = NVIDIAEvalImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
deps[Api.scoring],
|
||||
deps[Api.inference],
|
||||
deps[Api.agents],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
|
||||
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NVIDIAEvalConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
|
||||
"""
|
||||
|
||||
evaluator_url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the evaluator service",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
|
||||
}
|
||||
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
DEFAULT_NAMESPACE = "nvidia"
|
||||
|
||||
|
||||
class NVIDIAEvalImpl(
|
||||
Eval,
|
||||
BenchmarksProtocolPrivate,
|
||||
ModelRegistryHelper,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIAEvalConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
scoring_api: Scoring,
|
||||
inference_api: Inference,
|
||||
agents_api: Agents,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_api = scoring_api
|
||||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path, data):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
"/v1/evaluation/configs",
|
||||
{
|
||||
"namespace": DEFAULT_NAMESPACE,
|
||||
"name": task_def.benchmark_id,
|
||||
# metadata is copied to request body as-is
|
||||
**task_def.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation job for a benchmark."""
|
||||
model = (
|
||||
benchmark_config.eval_candidate.model
|
||||
if benchmark_config.eval_candidate.type == "model"
|
||||
else benchmark_config.eval_candidate.config.model
|
||||
)
|
||||
nvidia_model = self.get_provider_model_id(model) or model
|
||||
|
||||
result = await self._evaluator_post(
|
||||
"/v1/evaluation/jobs",
|
||||
{
|
||||
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
|
||||
"target": {"type": "model", "model": nvidia_model},
|
||||
},
|
||||
)
|
||||
|
||||
return Job(job_id=result["id"], status=JobStatus.in_progress)
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of an evaluation job.
|
||||
|
||||
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
|
||||
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
|
||||
"""
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
|
||||
result_status = result["status"]
|
||||
|
||||
job_status = JobStatus.failed
|
||||
if result_status in ["created", "pending"]:
|
||||
job_status = JobStatus.scheduled
|
||||
elif result_status in ["running"]:
|
||||
job_status = JobStatus.in_progress
|
||||
elif result_status in ["completed"]:
|
||||
job_status = JobStatus.completed
|
||||
elif result_status in ["cancelled"]:
|
||||
job_status = JobStatus.cancelled
|
||||
|
||||
return Job(job_id=job_id, status=job_status)
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel the evaluation job."""
|
||||
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Returns the results of the evaluation job."""
|
||||
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
|
||||
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
|
||||
|
||||
return EvaluateResponse(
|
||||
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
|
||||
generations=[],
|
||||
scores={
|
||||
benchmark_id: ScoringResult(
|
||||
score_rows=[],
|
||||
aggregated_results=result,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
|
@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
|
|||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
|
||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
|
|||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1"
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
|
|
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Allow non-llama model registration.
|
||||
|
||||
Non-llama model registration: API Catalogue models, post-training models, etc.
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.models.register(
|
||||
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
|
||||
model_type=ModelType.llm,
|
||||
provider_id="nvidia",
|
||||
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
|
||||
)
|
||||
|
||||
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
|
||||
"""
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
llama_model = model.metadata.get("llama_model")
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
# not llama model
|
||||
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
else:
|
||||
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -73,6 +72,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||
|
||||
from .models import model_entries
|
||||
|
||||
|
|
|
|||
|
|
@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
# Together client has no close method, so just set to None
|
||||
self._client = None
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def _lazy_initialize_client(self):
|
||||
if self.client is not None:
|
||||
return
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
|
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
|
@ -357,12 +368,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
assert self.client is not None
|
||||
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
|
||||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
res = await self.client.models.list()
|
||||
res = await client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
|
|
@ -413,6 +427,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
|
|
@ -452,6 +467,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
extra_body: Dict[str, Any] = {}
|
||||
|
|
@ -508,6 +524,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
|||
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
if not isinstance(config, WatsonXConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
||||
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
project_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
description="The watsonx API key, only needed of using the hosted service",
|
||||
)
|
||||
project_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key, only needed of using the hosted service",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
|
||||
"api_key": "${env.WATSONX_API_KEY:}",
|
||||
"project_id": "${env.WATSONX_PROJECT_ID:}",
|
||||
}
|
||||
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
CoreModelId.llama2_13b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from ibm_watson_machine_learning.foundation_models import Model
|
||||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
GreedySamplingStrategy,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
||||
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
|
||||
self._config = config
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
config_url = self._config.url
|
||||
project_id = self._config.project_id
|
||||
credentials = {"url": config_url, "apikey": config_api_key}
|
||||
|
||||
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/openai/v1",
|
||||
api_key=self._config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=choices,
|
||||
)
|
||||
return process_completion_response(response)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = self._get_client(request.model).generate_text_stream(**params)
|
||||
for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=None,
|
||||
text=chunk,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=choices,
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
model_id = request.model
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client(model_id).generate_text_stream(**params)
|
||||
for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=None,
|
||||
text=chunk,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
if request.sampling_params:
|
||||
if request.sampling_params.strategy:
|
||||
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||
if request.sampling_params.max_tokens:
|
||||
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||
if request.sampling_params.repetition_penalty:
|
||||
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||
|
||||
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||
|
||||
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||
|
||||
params = {
|
||||
**input_dict,
|
||||
}
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# watsonx.ai sometimes adds usage data to the stream
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
|
|
@ -36,7 +36,6 @@ import os
|
|||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
|
||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||
|
|
@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
|
|||
|
||||
### Inference with the fine-tuned model
|
||||
|
||||
#### 1. Register the model
|
||||
|
||||
```python
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
||||
client.models.register(
|
||||
model_id="test-example-model@v1",
|
||||
provider_id="nvidia",
|
||||
provider_model_id="test-example-model@v1",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
```
|
||||
|
||||
#### 2. Inference with the fine-tuned model
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
|
|
|
|||
|
|
@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||
# TODO: filter by available models based on /config endpoint
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
self.customizer_url = config.customizer_url
|
||||
self.session = None
|
||||
|
||||
self.customizer_url = config.customizer_url
|
||||
if not self.customizer_url:
|
||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||
self.customizer_url = "http://nemo.test"
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
return self.session
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
|
|
@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
if json and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
session = await self._get_session()
|
||||
for _ in range(self.config.max_retries):
|
||||
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
|
|
@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
jobs = []
|
||||
for job in response.get("data", []):
|
||||
job_id = job.pop("id")
|
||||
job_status = job.pop("status", "unknown").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
||||
job_status = job.pop("status", "scheduled").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
|
||||
|
||||
# Convert string timestamps to datetime objects
|
||||
created_at = (
|
||||
|
|
@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
)
|
||||
|
||||
api_status = response.pop("status").lower()
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
|
||||
|
||||
return NvidiaPostTrainingJobStatusResponse(
|
||||
status=JobStatus(mapped_status),
|
||||
|
|
@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
Supported models:
|
||||
- meta/llama-3.1-8b-instruct
|
||||
- meta/llama-3.2-1b-instruct
|
||||
|
||||
Supported algorithm configs:
|
||||
- LoRA, SFT
|
||||
|
|
@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
- LoRA config:
|
||||
## NeMo customizer specific LoRA parameters
|
||||
- adapter_dim: int - Adapter dimension
|
||||
Default: 8 (supports powers of 2)
|
||||
- adapter_dropout: float - Adapter dropout
|
||||
Default: None (0.0-1.0)
|
||||
- alpha: int - Scaling factor for the LoRA update
|
||||
Default: 16
|
||||
Note:
|
||||
|
|
@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
User is informed about unsupported parameters via warnings.
|
||||
"""
|
||||
# Map model to nvidia model name
|
||||
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
|
||||
# See `_MODEL_ENTRIES` for supported models
|
||||
nvidia_model = self.get_provider_model_id(model)
|
||||
|
||||
# Check for unsupported method parameters
|
||||
|
|
@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
},
|
||||
"data_config": {"dataset_id", "batch_size"},
|
||||
"optimizer_config": {"lr", "weight_decay"},
|
||||
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
|
||||
"lora_config": {"type", "alpha"},
|
||||
}
|
||||
|
||||
# Validate all parameters at once
|
||||
|
|
@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
# Handle LoRA-specific configuration
|
||||
if algorithm_config:
|
||||
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
|
||||
if algorithm_config.type == "LoRA":
|
||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||
job_config["hyperparameters"]["lora"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"adapter_dim": algorithm_config.get("adapter_dim"),
|
||||
"alpha": algorithm_config.get("alpha"),
|
||||
"adapter_dropout": algorithm_config.get("adapter_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
|
|
|||
|
|
@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
else:
|
||||
content = [await _convert_content(message.content)]
|
||||
|
||||
return {
|
||||
result = {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
result["tool_calls"] = []
|
||||
for tc in message.tool_calls:
|
||||
result["tool_calls"].append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.tool_name,
|
||||
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
|
||||
},
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UnparseableToolCall(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
|||
elif model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_2(request)
|
||||
# llama3.2, llama3.3 follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||
elif model.model_family == ModelFamily.llama4:
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||
else:
|
||||
messages = request.messages
|
||||
|
||||
|
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
|||
return messages
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama_3_2(
|
||||
def augment_messages_for_tools_llama(
|
||||
request: ChatCompletionRequest,
|
||||
custom_tool_prompt_generator,
|
||||
) -> List[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
|
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||
system_prompt = existing_system_message.content
|
||||
|
||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
||||
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
|
|
|||
|
|
@ -394,12 +394,10 @@
|
|||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
|
|
@ -411,7 +409,6 @@
|
|||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
|
|
@ -419,7 +416,6 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"ollama": [
|
||||
|
|
@ -759,5 +755,41 @@
|
|||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"watsonx": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"ibm_watson_machine_learning",
|
||||
"langdetect",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321
|
|||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-{{ name }} \
|
||||
|
|
@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
|||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-{{ name }} \
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Use NVIDIA NIM for running LLM inference and safety
|
||||
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
|
||||
providers:
|
||||
inference:
|
||||
- remote::nvidia
|
||||
|
|
@ -13,7 +13,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
- remote::nvidia
|
||||
post_training:
|
||||
- remote::nvidia
|
||||
datasetio:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
|
|
@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"safety": ["remote::nvidia"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"eval": ["remote::nvidia"],
|
||||
"post_training": ["remote::nvidia"],
|
||||
"datasetio": ["inline::localfs"],
|
||||
"scoring": ["inline::basic"],
|
||||
|
|
@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_type="remote::nvidia",
|
||||
config=NVIDIASafetyConfig.sample_run_config(),
|
||||
)
|
||||
eval_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIAEvalConfig.sample_run_config(),
|
||||
)
|
||||
inference_model = ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="nvidia",
|
||||
|
|
@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
return DistributionTemplate(
|
||||
name="nvidia",
|
||||
distro_type="self_hosted",
|
||||
description="Use NVIDIA NIM for running LLM inference and safety",
|
||||
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
|
|
@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"eval": [eval_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
|
|
@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"inference": [
|
||||
inference_provider,
|
||||
safety_provider,
|
||||
]
|
||||
],
|
||||
"eval": [eval_provider],
|
||||
},
|
||||
default_models=[inference_model, safety_model],
|
||||
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||
|
|
@ -90,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"",
|
||||
"NVIDIA API Key",
|
||||
),
|
||||
## Nemo Customizer related variables
|
||||
"NVIDIA_USER_ID": (
|
||||
"llama-stack-user",
|
||||
"NVIDIA User ID",
|
||||
"NVIDIA_APPEND_API_VERSION": (
|
||||
"True",
|
||||
"Whether to append the API version to the base_url",
|
||||
),
|
||||
## Nemo Customizer related variables
|
||||
"NVIDIA_DATASET_NAMESPACE": (
|
||||
"default",
|
||||
"NVIDIA Dataset Namespace",
|
||||
),
|
||||
"NVIDIA_ACCESS_POLICIES": (
|
||||
"{}",
|
||||
"NVIDIA Access Policies",
|
||||
),
|
||||
"NVIDIA_PROJECT_ID": (
|
||||
"test-project",
|
||||
"NVIDIA Project ID",
|
||||
|
|
@ -119,6 +123,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"http://0.0.0.0:7331",
|
||||
"URL for the NeMo Guardrails Service",
|
||||
),
|
||||
"NVIDIA_EVALUATOR_URL": (
|
||||
"http://0.0.0.0:7331",
|
||||
"URL for the NeMo Evaluator Service",
|
||||
),
|
||||
"INFERENCE_MODEL": (
|
||||
"Llama3.1-8B-Instruct",
|
||||
"Inference model",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ providers:
|
|||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
|
|
@ -53,13 +54,10 @@ providers:
|
|||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
||||
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ providers:
|
|||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
@ -48,13 +49,10 @@ providers:
|
|||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
||||
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
|
|
|
|||
7
llama_stack/templates/watsonx/__init__.py
Normal file
7
llama_stack/templates/watsonx/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .watsonx import get_distribution_template # noqa: F401
|
||||
30
llama_stack/templates/watsonx/build.yaml
Normal file
30
llama_stack/templates/watsonx/build.yaml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Use watsonx for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- remote::watsonx
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::code-interpreter
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
74
llama_stack/templates/watsonx/doc_template.md
Normal file
74
llama_stack/templates/watsonx/doc_template.md
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# watsonx Distribution
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 2
|
||||
:hidden:
|
||||
|
||||
self
|
||||
```
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
||||
Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
|
||||
|
||||
|
||||
## Running Llama Stack with watsonx
|
||||
|
||||
You can do this via Conda (build code), venv or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
docker run \
|
||||
-it \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
llama stack build --template watsonx --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
|
||||
```
|
||||
210
llama_stack/templates/watsonx/run.yaml
Normal file
210
llama_stack/templates/watsonx/run.yaml
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
version: '2'
|
||||
image_name: watsonx
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
config:
|
||||
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
|
||||
api_key: ${env.WATSONX_API_KEY:}
|
||||
project_id: ${env.WATSONX_PROJECT_ID:}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-3-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-2-13b-chat
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-2-13b
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-2-13b-chat
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-70b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-1-8b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-1b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-3b-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-guard-3-11b-vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: watsonx
|
||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
server:
|
||||
port: 8321
|
||||
90
llama_stack/templates/watsonx/watsonx.py
Normal file
90
llama_stack/templates/watsonx/watsonx.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::watsonx"],
|
||||
"vector_io": ["inline::faiss"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::code-interpreter",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
config=WatsonXConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"watsonx": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::code_interpreter",
|
||||
provider_id="code-interpreter",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name="watsonx",
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMASTACK_PORT": (
|
||||
"5001",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
"WATSONX_API_KEY": (
|
||||
"",
|
||||
"watsonx API Key",
|
||||
),
|
||||
"WATSONX_PROJECT_ID": (
|
||||
"",
|
||||
"watsonx Project ID",
|
||||
),
|
||||
},
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue