Merge remote-tracking branch 'mattf/add-nvidia-inference-adapter' into cdgamarose/add_nvidia_distro

Chantal D Gama Rose 2024-11-20 23:06:31 +00:00
commit 3b5ea74267
28 changed files with 432 additions and 483 deletions

View file

@@ -86,10 +86,13 @@ class PhotogenTool(SingleMessageBuiltinTool):
 class SearchTool(SingleMessageBuiltinTool):
     def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
         self.api_key = api_key
+        self.engine_type = engine
         if engine == SearchEngineType.bing:
             self.engine = BingSearch(api_key, **kwargs)
         elif engine == SearchEngineType.brave:
             self.engine = BraveSearch(api_key, **kwargs)
+        elif engine == SearchEngineType.tavily:
+            self.engine = TavilySearch(api_key, **kwargs)
         else:
             raise ValueError(f"Unknown search engine: {engine}")
@@ -257,6 +260,21 @@ class BraveSearch:
         return {"query": query, "top_k": clean_response}
 
 
+class TavilySearch:
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+
+    async def search(self, query: str) -> str:
+        response = requests.post(
+            "https://api.tavily.com/search",
+            json={"api_key": self.api_key, "query": query},
+        )
+        return json.dumps(self._clean_tavily_response(response.json()))
+
+    def _clean_tavily_response(self, search_response, top_k=3):
+        return {"query": search_response["query"], "top_k": search_response["results"]}
+
+
 class WolframAlphaTool(SingleMessageBuiltinTool):
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
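
Usage sketch: with this change, a SearchTool can be constructed for Tavily just like the other engines. A minimal example, assuming a Tavily key exported as TAVILY_SEARCH_API_KEY (the environment variable name follows the tests later in this commit; everything else comes from the diff above):

    import asyncio
    import os

    async def main():
        # SearchTool now dispatches on SearchEngineType.tavily and wraps TavilySearch
        tool = SearchTool(
            engine=SearchEngineType.tavily,
            api_key=os.environ["TAVILY_SEARCH_API_KEY"],
        )
        # returns a JSON string with "query" and "top_k" keys
        print(await tool.engine.search("latest Llama Stack release"))

    asyncio.run(main())

One quirk worth noting: _clean_tavily_response accepts a top_k parameter but returns all results unsliced; a caller that wants only the top three would need to slice search_response["results"][:top_k] itself.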

View file

@@ -50,11 +50,11 @@ MODEL_ALIASES = [
     ),
     build_model_alias(
         "fireworks/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-3b-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-11b-vision-instruct",
@@ -214,10 +214,10 @@ class FireworksInferenceAdapter(
         async def _to_async_generator():
             if "messages" in params:
-                stream = await self._get_client().chat.completions.acreate(**params)
+                stream = self._get_client().chat.completions.acreate(**params)
             else:
-                stream = self._get_client().completion.create(**params)
-            for chunk in stream:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
                 yield chunk
 
         stream = _to_async_generator()
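
The streaming fix above hinges on two details: the acreate call returns an async iterator directly (so it must not be awaited), and iterating it requires async for. A self-contained sketch of the same pattern, where fake_acreate stands in for the client call (it is illustrative, not part of the Fireworks SDK):

    import asyncio
    from typing import AsyncIterator

    async def fake_acreate(n: int) -> AsyncIterator[int]:
        # stands in for client.completion.acreate(**params):
        # an async iterator, not a coroutine, so it is not awaited
        for i in range(n):
            yield i

    async def _to_async_generator():
        stream = fake_acreate(3)    # note: no `await` here
        async for chunk in stream:  # `async for`, not plain `for`
            yield chunk

    async def main():
        async for chunk in _to_async_generator():
            print(chunk)

    asyncio.run(main())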

View file

@@ -17,7 +17,7 @@ class NVIDIAConfig(BaseModel):
     Configuration for the NVIDIA NIM inference endpoint.
 
     Attributes:
-        base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
+        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
         api_key (str): The access key for the hosted NIM endpoints
 
     There are two ways to access NVIDIA NIMs -
@@ -30,11 +30,11 @@ class NVIDIAConfig(BaseModel):
     By default the configuration will attempt to read the NVIDIA_API_KEY environment
     variable to set the api_key. Please do not put your API key in code.
 
-    If you are using a self-hosted NVIDIA NIM, you can set the base_url to the
+    If you are using a self-hosted NVIDIA NIM, you can set the url to the
     URL of your running NVIDIA NIM and do not need to set the api_key.
     """
 
-    base_url: str = Field(
+    url: str = Field(
         default="https://integrate.api.nvidia.com",
         description="A base url for accessing the NVIDIA NIM",
     )
@@ -49,7 +49,7 @@ class NVIDIAConfig(BaseModel):
     @property
     def is_hosted(self) -> bool:
-        return "integrate.api.nvidia.com" in self.base_url
+        return "integrate.api.nvidia.com" in self.url
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
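
The base_url → url rename is mechanical, but worth seeing in use. A small sketch (the localhost value is illustrative; per the docstring, api_key is read from NVIDIA_API_KEY when unset):

    # hosted: the default url points at integrate.api.nvidia.com
    hosted = NVIDIAConfig()
    assert hosted.is_hosted

    # self-hosted NIM: override url; no api_key is needed
    local = NVIDIAConfig(url="http://localhost:8000")
    assert not local.is_hosted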

View file

@@ -89,7 +89,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
 
-        print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...")
+        print(f"Initializing NVIDIAInferenceAdapter({config.url})...")
 
         if config.is_hosted:
             if not config.api_key:
@@ -110,7 +110,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self._config = config
         # make sure the client lives longer than any async calls
         self._client = AsyncOpenAI(
-            base_url=f"{self._config.base_url}/v1",
+            base_url=f"{self._config.url}/v1",
             api_key=self._config.api_key or "NO KEY",
             timeout=self._config.timeout,
         )
@@ -172,7 +172,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             response = await self._client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(
-                f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}"
+                f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
             ) from e
 
         if stream:
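
From a caller's perspective, the adapter surfaces connection failures as a plain ConnectionError carrying the configured url. A hedged sketch (the adapter variable and the chat_completion call are assumed from surrounding context, not shown in this diff):

    try:
        response = await adapter.chat_completion(**request)
    except ConnectionError as e:
        # e.g. "Failed to connect to NVIDIA NIM at http://localhost:8000: ..."
        print(e)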

View file

@@ -40,7 +40,7 @@ async def check_health(config: NVIDIAConfig) -> None:
     if not config.is_hosted:
         print("Checking NVIDIA NIM health...")
         try:
-            is_live, is_ready = await _get_health(config.base_url)
+            is_live, is_ready = await _get_health(config.url)
             if not is_live:
                 raise ConnectionError("NVIDIA NIM is not running")
             if not is_ready:
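
_get_health itself is not shown in this diff; a plausible shape for it, as a sketch only (the /v1/health/live and /v1/health/ready routes and the aiohttp dependency are assumptions, not taken from this commit):

    from typing import Tuple

    import aiohttp

    async def _get_health(url: str) -> Tuple[bool, bool]:
        # probe the liveness and readiness endpoints of a self-hosted NIM
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{url}/v1/health/live") as live:
                async with session.get(f"{url}/v1/health/ready") as ready:
                    return live.status == 200, ready.status == 200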

View file

@@ -264,6 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
+        print(f"Initializing TGI client with url={config.url}")
         self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]

View file

@@ -53,6 +53,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
+        print(f"Initializing VLLM client with base_url={self.config.url}")
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def shutdown(self) -> None:

View file

@@ -68,6 +68,73 @@ def query_attachment_messages():
     ]
 
 
+async def create_agent_turn_with_search_tool(
+    agents_stack: Dict[str, object],
+    search_query_messages: List[object],
+    common_params: Dict[str, str],
+    search_tool_definition: SearchToolDefinition,
+) -> None:
+    """
+    Create an agent turn with a search tool.
+
+    Args:
+        agents_stack (Dict[str, object]): The agents stack.
+        search_query_messages (List[object]): The search query messages.
+        common_params (Dict[str, str]): The common parameters.
+        search_tool_definition (SearchToolDefinition): The search tool definition.
+    """
+
+    # Create an agent with the search tool
+    agent_config = AgentConfig(
+        **{
+            **common_params,
+            "tools": [search_tool_definition],
+        }
+    )
+
+    agent_id, session_id = await create_agent_session(
+        agents_stack.impls[Api.agents], agent_config
+    )
+
+    turn_request = dict(
+        agent_id=agent_id,
+        session_id=session_id,
+        messages=search_query_messages,
+        stream=True,
+    )
+
+    turn_response = [
+        chunk
+        async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
+            **turn_request
+        )
+    ]
+
+    assert len(turn_response) > 0
+    assert all(
+        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+    )
+
+    check_event_types(turn_response)
+
+    # Check for tool execution events
+    tool_execution_events = [
+        chunk
+        for chunk in turn_response
+        if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
+        and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
+    ]
+    assert len(tool_execution_events) > 0, "No tool execution events found"
+
+    # Check the tool execution details
+    tool_execution = tool_execution_events[0].event.payload.step_details
+    assert isinstance(tool_execution, ToolExecutionStep)
+    assert len(tool_execution.tool_calls) > 0
+    assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
+    assert len(tool_execution.tool_responses) > 0
+
+    check_turn_complete_event(turn_response, session_id, search_query_messages)
+
+
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
@@ -215,63 +282,34 @@ class TestAgents:
     async def test_create_agent_turn_with_brave_search(
         self, agents_stack, search_query_messages, common_params
     ):
-        agents_impl = agents_stack.impls[Api.agents]
-
         if "BRAVE_SEARCH_API_KEY" not in os.environ:
             pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
 
-        # Create an agent with Brave search tool
-        agent_config = AgentConfig(
-            **{
-                **common_params,
-                "tools": [
-                    SearchToolDefinition(
-                        type=AgentTool.brave_search.value,
-                        api_key=os.environ["BRAVE_SEARCH_API_KEY"],
-                        engine=SearchEngineType.brave,
-                    )
-                ],
-            }
-        )
-
-        agent_id, session_id = await create_agent_session(agents_impl, agent_config)
-        turn_request = dict(
-            agent_id=agent_id,
-            session_id=session_id,
-            messages=search_query_messages,
-            stream=True,
-        )
-
-        turn_response = [
-            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-        ]
-
-        assert len(turn_response) > 0
-        assert all(
-            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
-        )
-
-        check_event_types(turn_response)
-
-        # Check for tool execution events
-        tool_execution_events = [
-            chunk
-            for chunk in turn_response
-            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
-            and chunk.event.payload.step_details.step_type
-            == StepType.tool_execution.value
-        ]
-        assert len(tool_execution_events) > 0, "No tool execution events found"
-
-        # Check the tool execution details
-        tool_execution = tool_execution_events[0].event.payload.step_details
-        assert isinstance(tool_execution, ToolExecutionStep)
-        assert len(tool_execution.tool_calls) > 0
-        assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
-        assert len(tool_execution.tool_responses) > 0
-
-        check_turn_complete_event(turn_response, session_id, search_query_messages)
+        search_tool_definition = SearchToolDefinition(
+            type=AgentTool.brave_search.value,
+            api_key=os.environ["BRAVE_SEARCH_API_KEY"],
+            engine=SearchEngineType.brave,
+        )
+        await create_agent_turn_with_search_tool(
+            agents_stack, search_query_messages, common_params, search_tool_definition
+        )
+
+    @pytest.mark.asyncio
+    async def test_create_agent_turn_with_tavily_search(
+        self, agents_stack, search_query_messages, common_params
+    ):
+        if "TAVILY_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+        search_tool_definition = SearchToolDefinition(
+            type=AgentTool.brave_search.value,  # place holder only
+            api_key=os.environ["TAVILY_SEARCH_API_KEY"],
+            engine=SearchEngineType.tavily,
+        )
+        await create_agent_turn_with_search_tool(
+            agents_stack, search_query_messages, common_params, search_tool_definition
+        )
 
 
 def check_event_types(turn_response):
     event_types = [chunk.event.payload.event_type for chunk in turn_response]
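
With the helper factored out, covering a further engine is only a few lines. A hypothetical Bing variant for illustration (BING_SEARCH_API_KEY and this test are not part of the commit; the brave_search type placeholder mirrors the tavily test above):

    @pytest.mark.asyncio
    async def test_create_agent_turn_with_bing_search(
        self, agents_stack, search_query_messages, common_params
    ):
        if "BING_SEARCH_API_KEY" not in os.environ:
            pytest.skip("BING_SEARCH_API_KEY not set, skipping test")

        search_tool_definition = SearchToolDefinition(
            type=AgentTool.brave_search.value,  # place holder, as in the tavily test
            api_key=os.environ["BING_SEARCH_API_KEY"],
            engine=SearchEngineType.bing,
        )
        await create_agent_turn_with_search_tool(
            agents_stack, search_query_messages, common_params, search_tool_definition
        )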

View file

@@ -25,7 +25,11 @@ from .utils import group_chunks
 def get_expected_stop_reason(model: str):
-    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
+    return (
+        StopReason.end_of_message
+        if ("Llama3.1" in model or "Llama-3.1" in model)
+        else StopReason.end_of_turn
+    )
 
 
 @pytest.fixture
@@ -34,7 +38,7 @@ def common_params(inference_model):
         "tool_choice": ToolChoice.auto,
         "tool_prompt_format": (
             ToolPromptFormat.json
-            if "Llama3.1" in inference_model
+            if ("Llama3.1" in inference_model or "Llama-3.1" in inference_model)
             else ToolPromptFormat.python_list
         ),
     }
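
The same ("Llama3.1" or "Llama-3.1") check now appears in two places; a regex would capture both spellings in one spot. A sketch (the helper name is illustrative, not part of this diff):

    import re

    def _is_llama_3_1(model: str) -> bool:
        # matches both the "Llama3.1" and "Llama-3.1" spellings
        return re.search(r"Llama-?3\.1", model) is not None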