Add tool calls to groq inference adapter

This commit is contained in:
Aidan Do 2024-12-14 20:48:38 +11:00
parent 485476c29a
commit d913fbeafe
4 changed files with 398 additions and 57 deletions

View file

@ -7,6 +7,7 @@
import warnings
from typing import AsyncIterator, List, Optional, Union
import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
@ -124,7 +125,14 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
)
)
response = self._get_client().chat.completions.create(**request)
try:
response = self._get_client().chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
else:
raise e
if stream:
return convert_chat_completion_response_stream(response)