[#432] Add Groq Provider - tool calls (#630)

# What does this PR do?

Contributes to issue #432

- Adds tool calls to Groq provider
- Enables tool call integration tests

### PR Train

- https://github.com/meta-llama/llama-stack/pull/609 
- https://github.com/meta-llama/llama-stack/pull/630 👈

## Test Plan
Environment:

```shell
export GROQ_API_KEY=<api-key>

# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml

# Create environment if not already
conda create --prefix ./envs python=3.10
conda activate ./envs

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq
```

<details>
<summary>Unit tests</summary>

```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s

# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .....................

======================================== 21 passed, 1 warning in 0.05s ========================================
```
</details>

<details>
<summary>Integration tests</summary>

```shell
# Run
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s

# Result
llama_stack/providers/tests/inference/test_text_inference.py .sss.s.ss.sss.s...

========================== 8 passed, 10 skipped, 180 deselected, 7 warnings in 2.73s ==========================
```
</details>

<details>
<summary>Manual</summary>

```bash
llama stack run ./run.yaml --port 5001
```

Via this Jupyter notebook:
9165502582/hello.ipynb
</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation. (no relevant documentation it
seems)
- [x] Wrote necessary unit or integration tests.
This commit is contained in:
Aidan Do 2025-01-14 13:17:38 +11:00 committed by GitHub
parent ace8dd6087
commit fdcc74fda2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 400 additions and 57 deletions

View file

@ -375,13 +375,13 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@ -422,11 +422,12 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(
@ -444,7 +445,6 @@ class TestInference:
**common_params,
)
]
assert len(response) > 0
assert all(
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response