Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-17 15:49:49 +00:00

Commit 3b5ea74267: Merge remote-tracking branch 'mattf/add-nvidia-inference-adapter' into cdgamarose/add_nvidia_distro

28 changed files with 432 additions and 483 deletions
@@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
 class SearchEngineType(Enum):
     bing = "bing"
     brave = "brave"
+    tavily = "tavily"


 @json_schema_type
@@ -380,6 +380,7 @@ def _hf_download(

 def _meta_download(
     model: "Model",
+    model_id: str,
     meta_url: str,
     info: "LlamaDownloadInfo",
     max_concurrent_downloads: int,
@@ -405,8 +406,15 @@ def _meta_download(
     downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
     asyncio.run(downloader.download_all(tasks))

-    print(f"\nSuccessfully downloaded model to {output_dir}")
-    cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
+    cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
+    cprint(
+        f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
+        "white",
+    )
+    cprint(
+        f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
+        "yellow",
+    )


 class ModelEntry(BaseModel):
@@ -512,7 +520,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
             )
             if "llamameta.net" not in meta_url:
                 parser.error("Invalid Meta URL provided")
-            _meta_download(model, meta_url, info, args.max_parallel)
+            _meta_download(model, model_id, meta_url, info, args.max_parallel)

     except Exception as e:
         parser.error(f"Download failed: {str(e)}")
@@ -9,6 +9,7 @@
 LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
+BUILD_PLATFORM=${BUILD_PLATFORM:-}

 if [ "$#" -lt 4 ]; then
   echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
@@ -96,7 +97,7 @@ else
   add_to_docker "RUN pip install fastapi libcst"
   add_to_docker <<EOF
 RUN pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
-  llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
+  llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
 EOF
 else
   add_to_docker "RUN pip install --no-cache llama-stack"
@@ -116,7 +117,6 @@ RUN pip install --no-cache $models_mount
 EOF
 fi

-
 add_to_docker <<EOF

 # This would be good in production but for debugging flexibility lets not add it right now
@@ -158,7 +158,9 @@ image_tag="$image_name:$version_tag"

 # Detect platform architecture
 ARCH=$(uname -m)
-if [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
+if [ -n "$BUILD_PLATFORM" ]; then
+  PLATFORM="--platform $BUILD_PLATFORM"
+elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
   PLATFORM="--platform linux/arm64"
 elif [ "$ARCH" = "x86_64" ]; then
   PLATFORM="--platform linux/amd64"
@@ -15,6 +15,8 @@ import httpx
 from pydantic import BaseModel, parse_obj_as
 from termcolor import cprint

+from llama_stack.apis.version import LLAMA_STACK_API_VERSION
+
 from llama_stack.providers.datatypes import RemoteProviderConfig

 _CLIENT_CLASSES = {}
@@ -117,7 +119,7 @@ def create_api_client_class(protocol) -> Type:
                     break
                 kwargs[param.name] = args[i]

-            url = f"{self.base_url}{webmethod.route}"
+            url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"

             def convert(value):
                 if isinstance(value, list):
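Every client request is now routed through a versioned path. A minimal sketch of the resulting URL composition; the base URL, route, and the version value "v1" are illustrative assumptions, not taken from the diff:

    # Sketch: how a webmethod route becomes a versioned request URL after this change.
    LLAMA_STACK_API_VERSION = "v1"  # assumed value; the real constant comes from llama_stack.apis.version

    base_url = "http://localhost:5000"    # hypothetical server address
    route = "/inference/chat_completion"  # hypothetical webmethod route

    url = f"{base_url}/{LLAMA_STACK_API_VERSION}/{route.lstrip('/')}"
    print(url)  # -> http://localhost:5000/v1/inference/chat_completion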
@@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
 class SearchTool(SingleMessageBuiltinTool):
     def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
         self.api_key = api_key
+        self.engine_type = engine
         if engine == SearchEngineType.bing:
             self.engine = BingSearch(api_key, **kwargs)
         elif engine == SearchEngineType.brave:
             self.engine = BraveSearch(api_key, **kwargs)
+        elif engine == SearchEngineType.tavily:
+            self.engine = TavilySearch(api_key, **kwargs)
         else:
             raise ValueError(f"Unknown search engine: {engine}")

@@ -257,6 +260,21 @@ class BraveSearch:
         return {"query": query, "top_k": clean_response}


+class TavilySearch:
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+
+    async def search(self, query: str) -> str:
+        response = requests.post(
+            "https://api.tavily.com/search",
+            json={"api_key": self.api_key, "query": query},
+        )
+        return json.dumps(self._clean_tavily_response(response.json()))
+
+    def _clean_tavily_response(self, search_response, top_k=3):
+        return {"query": search_response["query"], "top_k": search_response["results"]}
+
+
 class WolframAlphaTool(SingleMessageBuiltinTool):
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
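Combined with the SearchEngineType and SearchTool hunks above, a Tavily-backed tool can be constructed like the existing engines. A minimal usage sketch; the import path and API key are assumptions, and the network call is not exercised here:

    import asyncio

    # Assumed import path for the builtin tools module touched by this commit.
    from llama_stack.providers.inline.agents.meta_reference.tools.builtin import (
        SearchEngineType,
        SearchTool,
    )

    async def main() -> None:
        tool = SearchTool(engine=SearchEngineType.tavily, api_key="tvly-...")  # placeholder key
        # TavilySearch.search returns a JSON string with "query" and "top_k" keys
        # (see _clean_tavily_response above).
        print(await tool.engine.search("latest llama-stack release"))

    asyncio.run(main())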
@@ -50,11 +50,11 @@ MODEL_ALIASES = [
     ),
     build_model_alias(
         "fireworks/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-3b-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-11b-vision-instruct",
@@ -214,10 +214,10 @@ class FireworksInferenceAdapter(

         async def _to_async_generator():
             if "messages" in params:
-                stream = await self._get_client().chat.completions.acreate(**params)
+                stream = self._get_client().chat.completions.acreate(**params)
             else:
-                stream = self._get_client().completion.create(**params)
-            for chunk in stream:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
                 yield chunk

         stream = _to_async_generator()
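The Fireworks fix works because the client's async streaming calls hand back an async iterator rather than an awaitable list, so the wrapper must drive it with async for. A self-contained sketch of the same pattern, with a stand-in stream in place of the real Fireworks client:

    import asyncio
    from typing import AsyncIterator

    async def fake_stream() -> AsyncIterator[str]:
        # Stand-in for chat.completions.acreate(..., stream=True): yields chunks asynchronously.
        for chunk in ["Hel", "lo", "!"]:
            await asyncio.sleep(0)
            yield chunk

    async def _to_async_generator() -> AsyncIterator[str]:
        stream = fake_stream()      # no await: the call itself returns an async generator
        async for chunk in stream:  # async for, because chunks arrive asynchronously
            yield chunk

    async def main() -> None:
        async for chunk in _to_async_generator():
            print(chunk, end="")

    asyncio.run(main())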
@@ -17,7 +17,7 @@ class NVIDIAConfig(BaseModel):
     Configuration for the NVIDIA NIM inference endpoint.

     Attributes:
-        base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
+        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
         api_key (str): The access key for the hosted NIM endpoints

     There are two ways to access NVIDIA NIMs -
@@ -30,11 +30,11 @@ class NVIDIAConfig(BaseModel):
     By default the configuration will attempt to read the NVIDIA_API_KEY environment
     variable to set the api_key. Please do not put your API key in code.

-    If you are using a self-hosted NVIDIA NIM, you can set the base_url to the
+    If you are using a self-hosted NVIDIA NIM, you can set the url to the
     URL of your running NVIDIA NIM and do not need to set the api_key.
     """

-    base_url: str = Field(
+    url: str = Field(
         default="https://integrate.api.nvidia.com",
         description="A base url for accessing the NVIDIA NIM",
     )
@@ -49,7 +49,7 @@ class NVIDIAConfig(BaseModel):

     @property
     def is_hosted(self) -> bool:
-        return "integrate.api.nvidia.com" in self.base_url
+        return "integrate.api.nvidia.com" in self.url

     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
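After the rename, the adapter is configured through a single url field and is_hosted is derived from it. A minimal sketch of constructing the config for self-hosted versus hosted NIM endpoints; the import path and key value are assumptions:

    # Sketch only: exercising the renamed NVIDIAConfig fields shown above.
    from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig  # assumed import path

    local_config = NVIDIAConfig(url="http://localhost:8000")  # self-hosted NIM, no api_key needed
    hosted_config = NVIDIAConfig(api_key="nvapi-...")         # placeholder key; url defaults to https://integrate.api.nvidia.com

    assert not local_config.is_hosted
    assert hosted_config.is_hosted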
@@ -89,7 +89,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)

-        print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...")
+        print(f"Initializing NVIDIAInferenceAdapter({config.url})...")

         if config.is_hosted:
             if not config.api_key:
@@ -110,7 +110,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self._config = config
         # make sure the client lives longer than any async calls
         self._client = AsyncOpenAI(
-            base_url=f"{self._config.base_url}/v1",
+            base_url=f"{self._config.url}/v1",
             api_key=self._config.api_key or "NO KEY",
             timeout=self._config.timeout,
         )
@@ -172,7 +172,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             response = await self._client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(
-                f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}"
+                f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
             ) from e

         if stream:
@@ -40,7 +40,7 @@ async def check_health(config: NVIDIAConfig) -> None:
     if not config.is_hosted:
         print("Checking NVIDIA NIM health...")
         try:
-            is_live, is_ready = await _get_health(config.base_url)
+            is_live, is_ready = await _get_health(config.url)
             if not is_live:
                 raise ConnectionError("NVIDIA NIM is not running")
             if not is_ready:
@@ -264,6 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):

 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
+        print(f"Initializing TGI client with url={config.url}")
         self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
@@ -53,6 +53,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None

     async def initialize(self) -> None:
+        print(f"Initializing VLLM client with base_url={self.config.url}")
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

     async def shutdown(self) -> None:
@@ -68,6 +68,73 @@ def query_attachment_messages():
     ]


+async def create_agent_turn_with_search_tool(
+    agents_stack: Dict[str, object],
+    search_query_messages: List[object],
+    common_params: Dict[str, str],
+    search_tool_definition: SearchToolDefinition,
+) -> None:
+    """
+    Create an agent turn with a search tool.
+
+    Args:
+        agents_stack (Dict[str, object]): The agents stack.
+        search_query_messages (List[object]): The search query messages.
+        common_params (Dict[str, str]): The common parameters.
+        search_tool_definition (SearchToolDefinition): The search tool definition.
+    """
+
+    # Create an agent with the search tool
+    agent_config = AgentConfig(
+        **{
+            **common_params,
+            "tools": [search_tool_definition],
+        }
+    )
+
+    agent_id, session_id = await create_agent_session(
+        agents_stack.impls[Api.agents], agent_config
+    )
+    turn_request = dict(
+        agent_id=agent_id,
+        session_id=session_id,
+        messages=search_query_messages,
+        stream=True,
+    )
+
+    turn_response = [
+        chunk
+        async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
+            **turn_request
+        )
+    ]
+
+    assert len(turn_response) > 0
+    assert all(
+        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+    )
+
+    check_event_types(turn_response)
+
+    # Check for tool execution events
+    tool_execution_events = [
+        chunk
+        for chunk in turn_response
+        if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
+        and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
+    ]
+    assert len(tool_execution_events) > 0, "No tool execution events found"
+
+    # Check the tool execution details
+    tool_execution = tool_execution_events[0].event.payload.step_details
+    assert isinstance(tool_execution, ToolExecutionStep)
+    assert len(tool_execution.tool_calls) > 0
+    assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
+    assert len(tool_execution.tool_responses) > 0
+
+    check_turn_complete_event(turn_response, session_id, search_query_messages)
+
+
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
@@ -215,63 +282,34 @@ class TestAgents:
     async def test_create_agent_turn_with_brave_search(
         self, agents_stack, search_query_messages, common_params
     ):
-        agents_impl = agents_stack.impls[Api.agents]
-
         if "BRAVE_SEARCH_API_KEY" not in os.environ:
             pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")

-        # Create an agent with Brave search tool
-        agent_config = AgentConfig(
-            **{
-                **common_params,
-                "tools": [
-                    SearchToolDefinition(
-                        type=AgentTool.brave_search.value,
-                        api_key=os.environ["BRAVE_SEARCH_API_KEY"],
-                        engine=SearchEngineType.brave,
-                    )
-                ],
-            }
+        search_tool_definition = SearchToolDefinition(
+            type=AgentTool.brave_search.value,
+            api_key=os.environ["BRAVE_SEARCH_API_KEY"],
+            engine=SearchEngineType.brave,
+        )
+        await create_agent_turn_with_search_tool(
+            agents_stack, search_query_messages, common_params, search_tool_definition
         )

-        agent_id, session_id = await create_agent_session(agents_impl, agent_config)
-        turn_request = dict(
-            agent_id=agent_id,
-            session_id=session_id,
-            messages=search_query_messages,
-            stream=True,
+    @pytest.mark.asyncio
+    async def test_create_agent_turn_with_tavily_search(
+        self, agents_stack, search_query_messages, common_params
+    ):
+        if "TAVILY_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+        search_tool_definition = SearchToolDefinition(
+            type=AgentTool.brave_search.value,  # place holder only
+            api_key=os.environ["TAVILY_SEARCH_API_KEY"],
+            engine=SearchEngineType.tavily,
         )
-
-        turn_response = [
-            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-        ]
-
-        assert len(turn_response) > 0
-        assert all(
-            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+        await create_agent_turn_with_search_tool(
+            agents_stack, search_query_messages, common_params, search_tool_definition
         )

-        check_event_types(turn_response)
-
-        # Check for tool execution events
-        tool_execution_events = [
-            chunk
-            for chunk in turn_response
-            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
-            and chunk.event.payload.step_details.step_type
-            == StepType.tool_execution.value
-        ]
-        assert len(tool_execution_events) > 0, "No tool execution events found"
-
-        # Check the tool execution details
-        tool_execution = tool_execution_events[0].event.payload.step_details
-        assert isinstance(tool_execution, ToolExecutionStep)
-        assert len(tool_execution.tool_calls) > 0
-        assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
-        assert len(tool_execution.tool_responses) > 0
-
-        check_turn_complete_event(turn_response, session_id, search_query_messages)


 def check_event_types(turn_response):
     event_types = [chunk.event.payload.event_type for chunk in turn_response]
@@ -25,7 +25,11 @@ from .utils import group_chunks


 def get_expected_stop_reason(model: str):
-    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
+    return (
+        StopReason.end_of_message
+        if ("Llama3.1" in model or "Llama-3.1" in model)
+        else StopReason.end_of_turn
+    )


 @pytest.fixture
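The helper now accepts both spellings of the 3.1 model names. For illustration only; the model identifiers are examples, and get_expected_stop_reason/StopReason are assumed to be in scope:

    assert get_expected_stop_reason("Llama3.1-8B-Instruct") == StopReason.end_of_message
    assert get_expected_stop_reason("meta-llama/Llama-3.1-8B-Instruct") == StopReason.end_of_message
    assert get_expected_stop_reason("meta-llama/Llama-3.2-3B-Instruct") == StopReason.end_of_turn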
@@ -34,7 +38,7 @@ def common_params(inference_model):
         "tool_choice": ToolChoice.auto,
         "tool_prompt_format": (
             ToolPromptFormat.json
-            if "Llama3.1" in inference_model
+            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
             else ToolPromptFormat.python_list
         ),
     }
@@ -6,6 +6,7 @@

 import concurrent.futures
+import importlib
 import json
 import subprocess
 import sys
 from functools import partial
@@ -14,6 +15,11 @@ from typing import Iterator

 from rich.progress import Progress, SpinnerColumn, TextColumn

+from llama_stack.distribution.build import (
+    get_provider_dependencies,
+    SERVER_DEPENDENCIES,
+)
+

 REPO_ROOT = Path(__file__).parent.parent.parent

@@ -67,6 +73,39 @@ def check_for_changes() -> bool:
     return result.returncode != 0


+def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
+    try:
+        module_name = f"llama_stack.templates.{template_dir.name}"
+        module = importlib.import_module(module_name)
+
+        if template_func := getattr(module, "get_distribution_template", None):
+            template = template_func()
+            normal_deps, special_deps = get_provider_dependencies(template.providers)
+            # Combine all dependencies in order: normal deps, special deps, server deps
+            all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(
+                list(set(special_deps))
+            )
+
+            return template.name, all_deps
+    except Exception:
+        return None, []
+    return None, []
+
+
+def generate_dependencies_file():
+    templates_dir = REPO_ROOT / "llama_stack" / "templates"
+    distribution_deps = {}
+
+    for template_dir in find_template_dirs(templates_dir):
+        name, deps = collect_template_dependencies(template_dir)
+        if name:
+            distribution_deps[name] = deps
+
+    deps_file = REPO_ROOT / "distributions" / "dependencies.json"
+    with open(deps_file, "w") as f:
+        json.dump(distribution_deps, f, indent=2)
+
+
 def main():
     templates_dir = REPO_ROOT / "llama_stack" / "templates"

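generate_dependencies_file writes a mapping from each template name to its combined dependency list into distributions/dependencies.json. A hypothetical excerpt of that shape; the template names and packages below are illustrative only, not taken from the commit:

    # Illustrative only: the structure json.dump receives, not the actual contents of dependencies.json.
    distribution_deps = {
        "fireworks": ["fastapi", "fireworks-ai", "uvicorn"],
        "tgi": ["fastapi", "huggingface_hub", "uvicorn"],
    }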
@@ -88,6 +127,8 @@ def main():
             list(executor.map(process_func, template_dirs))
             progress.update(task, advance=len(template_dirs))

+    generate_dependencies_file()
+
     if check_for_changes():
         print(
             "Distribution template changes detected. Please commit the changes.",
@@ -57,11 +57,11 @@ models:
   provider_id: null
   provider_model_id: fireworks/llama-v3p1-405b-instruct
 - metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
+  model_id: meta-llama/Llama-3.2-1B-Instruct
   provider_id: null
   provider_model_id: fireworks/llama-v3p2-1b-instruct
 - metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  model_id: meta-llama/Llama-3.2-3B-Instruct
   provider_id: null
   provider_model_id: fireworks/llama-v3p2-3b-instruct
 - metadata: {}
@@ -2,7 +2,7 @@ version: '2'
 name: tgi
 distribution_spec:
   description: Use (an external) TGI server for running LLM inference
-  docker_image: llamastack/distribution-tgi:test-0.0.52rc3
+  docker_image: null
 providers:
   inference:
   - remote::tgi
@@ -1,6 +1,6 @@
 version: '2'
 image_name: tgi
-docker_image: llamastack/distribution-tgi:test-0.0.52rc3
+docker_image: null
 conda_env: tgi
 apis:
 - agents
@@ -1,6 +1,6 @@
 version: '2'
 image_name: tgi
-docker_image: llamastack/distribution-tgi:test-0.0.52rc3
+docker_image: null
 conda_env: tgi
 apis:
 - agents
@@ -41,7 +41,7 @@ def get_distribution_template() -> DistributionTemplate:
         name="tgi",
         distro_type="self_hosted",
         description="Use (an external) TGI server for running LLM inference",
-        docker_image="llamastack/distribution-tgi:test-0.0.52rc3",
+        docker_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
         default_models=[inference_model, safety_model],