Add tool calls to groq inference adapter

This commit is contained in:
Aidan Do 2024-12-14 20:48:38 +11:00
parent 485476c29a
commit d913fbeafe
4 changed files with 398 additions and 57 deletions

View file

@ -373,13 +373,13 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@ -420,11 +420,12 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
@ -442,7 +443,6 @@ class TestInference:
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response