# What does this PR do?

Imports `ToolConfig` from the `llama_stack.apis.inference` module to resolve a missing reference so that `groq.py` can be imported again.

Signed-off-by: Sébastien Han <seb@redhat.com>

## Test Plan

Without the change, pytest fails during collection with the following error:

```
uv run pytest -v -s -k "ollama" llama_stack/providers/tests/
/Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================ test session starts =============================================
platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 379 items / 1 error / 349 deselected / 30 selected

=================================================== ERRORS ===================================================
__________________ ERROR collecting llama_stack/providers/tests/inference/groq/test_init.py __________________
llama_stack/providers/tests/inference/groq/test_init.py:11: in <module>
    from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
llama_stack/providers/remote/inference/groq/groq.py:72: in <module>
    class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
llama_stack/providers/remote/inference/groq/groq.py:102: in GroqInferenceAdapter
    tool_config: Optional[ToolConfig] = None,
E   NameError: name 'ToolConfig' is not defined
========================================== short test summary info ===========================================
ERROR llama_stack/providers/tests/inference/groq/test_init.py - NameError: name 'ToolConfig' is not defined
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
=============================== 349 deselected, 22 warnings, 1 error in 0.28s ================================
```

With the change, collection proceeds past `groq.py` and the run fails with a different, unrelated error:

```
uv run pytest -v -s llama_stack/providers/tests/
/Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
============================================ test session starts =============================================
platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 342 items / 1 error

=================================================== ERRORS ===================================================
______________ ERROR collecting llama_stack/providers/tests/inference/test_vision_inference.py _______________
llama_stack/providers/tests/inference/test_vision_inference.py:29: in <module>
    class TestVisionModelInference:
llama_stack/providers/tests/inference/test_vision_inference.py:35: in TestVisionModelInference
    ImageContentItem(image=dict(data=PASTA_IMAGE)),
E   pydantic_core._pydantic_core.ValidationError: 1 validation error for ImageContentItem
E   image.data
E     Input should be a valid string, unable to parse raw data as a unicode string [type=string_unicode, input_value=b'\xff\xd8\xff\xe0\x00\x1...0\xe6\x9f5\xb5?\xff\xd9', input_type=bytes]
E     For further information visit https://errors.pydantic.dev/2.10/v/string_unicode
========================================== short test summary info ===========================================
ERROR llama_stack/providers/tests/inference/test_vision_inference.py - pydantic_core._pydantic_core.ValidationError: 1 validation error for ImageContentItem
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
======================================= 22 warnings, 1 error in 0.25s ========================================
```

That error is fixed in https://github.com/meta-llama/llama-stack/pull/1003.

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

Signed-off-by: Sébastien Han <seb@redhat.com>
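For reference, the fix is a one-line addition to the existing `llama_stack.apis.inference` import block in `groq.py` (the full file is reproduced below). A minimal sketch, abbreviated to a handful of the imported symbols:

```python
# Sketch of the fix: ToolConfig joins the existing llama_stack.apis.inference
# imports so the chat_completion() signature can reference it at class-definition time.
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    Inference,
    ToolChoice,
    ToolConfig,  # previously missing, which caused the NameError during test collection
)
```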
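The remaining `ImageContentItem` failure is a Pydantic v2 validation error: a `str` field receives raw JPEG bytes that cannot be decoded as UTF-8. The snippet below only illustrates that behaviour with hypothetical stand-in models (`_Image`, `_Item`), not the real llama-stack classes, and the base64 workaround is merely one way to satisfy a string field; the actual fix is the one in the PR linked above.

```python
# Illustration only: stand-in models mirroring the shape implied by the error
# ("1 validation error for ImageContentItem ... image.data").
import base64

from pydantic import BaseModel, ValidationError


class _Image(BaseModel):
    data: str  # bytes passed here must be valid UTF-8 to be coerced to str


class _Item(BaseModel):
    image: _Image


raw_jpeg = b"\xff\xd8\xff\xe0"  # truncated JPEG header; not valid UTF-8

try:
    _Item(image=dict(data=raw_jpeg))
except ValidationError as exc:
    # Prints: "Input should be a valid string, unable to parse raw data as a
    # unicode string [type=string_unicode, ...]"
    print(exc)

# Base64-encoding produces a plain ASCII string, which the field accepts.
_Item(image=dict(data=base64.b64encode(raw_jpeg).decode("ascii")))
```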
156 lines
5.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
    ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    build_model_alias_with_just_provider_model_id,
    ModelRegistryHelper,
)

from .groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)

_MODEL_ALIASES = [
    build_model_alias(
        "llama3-8b-8192",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama-3.1-8b-instant",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3-70b-8192",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b-versatile",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # Groq only contains a preview version for llama-3.2-3b
    # Preview models aren't recommended for production use, but we include this one
    # to pass the test fixture
    # TODO(aidand): Replace this with a stable model once Groq supports it
    build_model_alias(
        "llama-3.2-3b-preview",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
    _config: GroqConfig

    def __init__(self, config: GroqConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        model_id = self.get_provider_model_id(model_id)
        if model_id == "llama-3.2-3b-preview":
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                stream=stream,
                logprobs=logprobs,
                tool_config=tool_config,
            )
        )

        try:
            response = self._get_client().chat.completions.create(**request)
        except groq.BadRequestError as e:
            if e.body.get("error", {}).get("code") == "tool_use_failed":
                # For smaller models, Groq may fail to call a tool even when the request is well formed
                raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
            else:
                raise e

        if stream:
            return convert_chat_completion_response_stream(response)
        else:
            return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    def _get_client(self) -> Groq:
        if self._config.api_key is not None:
            return Groq(api_key=self._config.api_key)
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.groq_api_key:
                raise ValueError(
                    'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
                )
            return Groq(api_key=provider_data.groq_api_key)