llama-stack/llama_stack/providers/remote/inference/groq/groq.py
Sébastien Han 657f24b964
chore: add missing ToolConfig import in groq.py (#983)
# What does this PR do?

Import `ToolConfig` from the `llama_stack.apis.inference` module in
`groq.py`. The name was used in a type annotation without being
imported, which raised a `NameError` when the module was loaded.

Signed-off-by: Sébastien Han <seb@redhat.com>


## Test Plan

Without the change, pytest fails during collection with the following error:

```
uv run pytest -v -s -k "ollama" llama_stack/providers/tests/
/Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================ test session starts =============================================
platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 379 items / 1 error / 349 deselected / 30 selected                                                 

=================================================== ERRORS ===================================================
__________________ ERROR collecting llama_stack/providers/tests/inference/groq/test_init.py __________________
llama_stack/providers/tests/inference/groq/test_init.py:11: in <module>
    from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
llama_stack/providers/remote/inference/groq/groq.py:72: in <module>
    class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
llama_stack/providers/remote/inference/groq/groq.py:102: in GroqInferenceAdapter
    tool_config: Optional[ToolConfig] = None,
E   NameError: name 'ToolConfig' is not defined
========================================== short test summary info ===========================================
ERROR llama_stack/providers/tests/inference/groq/test_init.py - NameError: name 'ToolConfig' is not defined
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=============================== 349 deselected, 22 warnings, 1 error in 0.28s ================================
```
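
The traceback points at the method signature rather than at a call site. On the Python 3.13.1 interpreter shown in the log, parameter annotations are evaluated when the `def` statement runs, so an unimported name fails as soon as the module is imported. Below is a minimal sketch of that behaviour, using a stand-in `Adapter` class and a placeholder `ToolConfig` (both hypothetical, not the real llama-stack types). Interpreters that defer annotation evaluation may not raise at definition time, so the first result can differ there.

```python
from typing import Optional

# Stand-in source for a class whose method annotates a parameter with ToolConfig.
ADAPTER_SOURCE = '''
class Adapter:
    def chat_completion(self, tool_config: Optional[ToolConfig] = None):
        return tool_config
'''


def define_adapter(namespace: dict) -> Optional[str]:
    """Execute the class definition; return the NameError message, if any."""
    try:
        exec(ADAPTER_SOURCE, namespace)
    except NameError as exc:
        return str(exc)
    return None


class ToolConfig:
    """Placeholder for the real llama_stack.apis.inference.ToolConfig."""


# ToolConfig missing from the namespace, as before the fix.
missing = define_adapter({"Optional": Optional})
# ToolConfig present in the namespace, as after the import is added.
present = define_adapter({"Optional": Optional, "ToolConfig": ToolConfig})

print("without import:", missing)
print("with import:", present)
```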

With the change, the test run gets past `groq.py` and fails with a
different error:

```
uv run pytest -v -s llama_stack/providers/tests/
/Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================ test session starts =============================================
platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 342 items / 1 error                                                                                

=================================================== ERRORS ===================================================
______________ ERROR collecting llama_stack/providers/tests/inference/test_vision_inference.py _______________
llama_stack/providers/tests/inference/test_vision_inference.py:29: in <module>
    class TestVisionModelInference:
llama_stack/providers/tests/inference/test_vision_inference.py:35: in TestVisionModelInference
    ImageContentItem(image=dict(data=PASTA_IMAGE)),
E   pydantic_core._pydantic_core.ValidationError: 1 validation error for ImageContentItem
E   image.data
E     Input should be a valid string, unable to parse raw data as a unicode string [type=string_unicode, input_value=b'\xff\xd8\xff\xe0\x00\x1...0\xe6\x9f5\xb5?\xff\xd9', input_type=bytes]
E       For further information visit https://errors.pydantic.dev/2.10/v/string_unicode
========================================== short test summary info ===========================================
ERROR llama_stack/providers/tests/inference/test_vision_inference.py - pydantic_core._pydantic_core.ValidationError: 1 validation error for ImageContentItem
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
======================================= 22 warnings, 1 error in 0.25s ========================================
```

That error is fixed in https://github.com/meta-llama/llama-stack/pull/1003.
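
The second failure comes from passing raw JPEG `bytes` where the model expects a `str`: the pydantic message says the data cannot be parsed as a Unicode string. The sketch below reproduces only that underlying decoding problem with the standard library and shows base64 as one common way to carry binary data in a string field. Whether the linked PR takes this approach is not stated here, and `JPEG_PREFIX` is an illustrative stand-in for the test image.

```python
import base64

# First bytes of a JPEG file, matching the start of the input_value in the error.
JPEG_PREFIX = b"\xff\xd8\xff\xe0"


def try_decode(data: bytes) -> bool:
    """Return True if the bytes are valid UTF-8 text, False otherwise."""
    try:
        data.decode("utf-8")
    except UnicodeDecodeError:
        return False
    return True


def to_base64_text(data: bytes) -> str:
    """Encode arbitrary bytes as an ASCII-safe base64 string."""
    return base64.b64encode(data).decode("ascii")


encoded = to_base64_text(JPEG_PREFIX)
print("valid UTF-8:", try_decode(JPEG_PREFIX))
print("base64 text:", encoded)
```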

2025-02-07 09:35:00 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
    ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    ModelRegistryHelper,
)

from .groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        model_id = self.get_provider_model_id(model_id)
        if model_id == "llama-3.2-3b-preview":
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                stream=stream,
                logprobs=logprobs,
                tool_config=tool_config,
            )
        )

        try:
            response = self._get_client().chat.completions.create(**request)
        except groq.BadRequestError as e:
            if e.body.get("error", {}).get("code") == "tool_use_failed":
                # For smaller models, Groq may fail to call a tool even when the request is well formed
                raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
            else:
                raise e

        if stream:
            return convert_chat_completion_response_stream(response)
        else:
            return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    def _get_client(self) -> Groq:
        if self._config.api_key is not None:
            return Groq(api_key=self._config.api_key)
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.groq_api_key:
                raise ValueError(
                    'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
                )
            return Groq(api_key=provider_data.groq_api_key)
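
`_get_client` above resolves the API key in two steps: a key set in the provider config wins, and otherwise the key must arrive with the request in the `X-LlamaStack-Provider-Data` header. The sketch below isolates that fallback order without the Groq SDK. `resolve_api_key` and the `ProviderData` dataclass are hypothetical names used only for illustration; they are not part of llama-stack.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProviderData:
    """Hypothetical stand-in for the parsed X-LlamaStack-Provider-Data header."""

    groq_api_key: Optional[str] = None


def resolve_api_key(config_key: Optional[str], provider_data: Optional[ProviderData]) -> str:
    """Prefer the configured key; otherwise require one from the request header."""
    if config_key is not None:
        return config_key
    if provider_data is None or not provider_data.groq_api_key:
        raise ValueError("no Groq API key in config or in X-LlamaStack-Provider-Data")
    return provider_data.groq_api_key


# The configured key takes precedence over the per-request key.
print(resolve_api_key("from-config", ProviderData("from-header")))
# With no configured key, the per-request key is used.
print(resolve_api_key(None, ProviderData("from-header")))
```

Checking the config first means a server-side key cannot be overridden by a client header, while deployments without a configured key can still accept one per request.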