From 928a39d17bc3108ed64954481f067543677434a7 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 27 Feb 2025 13:16:50 -0800
Subject: [PATCH] feat(providers): Groq now uses LiteLLM openai-compat (#1303)

Groq has never supported raw completions anyhow, so this makes it easier to
switch it to LiteLLM. Our entire test suite passes.

I also updated all the openai-compat providers so they work with API keys
passed from request headers via `provider_data`.

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
   pytest -s -v tests/client-sdk/inference/test_text_inference.py \
   --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested the other openai-compat providers (openai, anthropic, gemini). No regressions.
---
 distributions/dependencies.json | 1 +
 .../distributions/self_hosted_distro/groq.md | 10 +-
 llama_stack/providers/registry/inference.py | 23 +-
 .../remote/inference/anthropic/anthropic.py | 11 +-
 .../remote/inference/anthropic/config.py | 7 +
 .../remote/inference/gemini/config.py | 7 +
 .../remote/inference/gemini/gemini.py | 11 +-
 .../remote/inference/groq/__init__.py | 9 -
 .../providers/remote/inference/groq/config.py | 11 +-
 .../providers/remote/inference/groq/groq.py | 130 +---
 .../remote/inference/groq/groq_utils.py | 245 --------
 .../providers/remote/inference/groq/models.py | 12 +-
 .../remote/inference/openai/config.py | 7 +
 .../remote/inference/openai/openai.py | 11 +-
 .../tests/inference/groq/test_groq_utils.py | 575 ------------------
 .../utils/inference/litellm_openai_mixin.py | 21 +-
 llama_stack/templates/dev/build.yaml | 1 +
 llama_stack/templates/dev/dev.py | 17 +-
 llama_stack/templates/dev/run.yaml | 30 +
 llama_stack/templates/groq/groq.py | 14 +-
 llama_stack/templates/groq/run.yaml | 10 +-
 tests/client-sdk/conftest.py | 4 +-
 .../inference/test_text_inference.py | 2 +-
 23 files changed, 165 insertions(+), 1004 deletions(-)
 delete mode 100644 llama_stack/providers/remote/inference/groq/groq_utils.py
 delete mode 100644 llama_stack/providers/tests/inference/groq/test_groq_utils.py

diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 622cf791b..b147b7df6 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -146,6 +146,7 @@ "fastapi", "fire", "fireworks-ai", + "groq", "httpx", "litellm", "matplotlib", diff --git a/docs/source/distributions/self_hosted_distro/groq.md b/docs/source/distributions/self_hosted_distro/groq.md index 296a5f49b..9fb7b2619 100644 --- a/docs/source/distributions/self_hosted_distro/groq.md +++ b/docs/source/distributions/self_hosted_distro/groq.md @@ -37,11 +37,11 @@ The following environment variables can be configured: The following models are available by default: -- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)` -- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)` -- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)` -- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)` -- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)` +- `meta-llama/Llama-3.1-8B-Instruct (groq/llama3-8b-8192)` +- `meta-llama/Llama-3.1-8B-Instruct (groq/llama-3.1-8b-instant)` +- `meta-llama/Llama-3-70B-Instruct (groq/llama3-70b-8192)` +- `meta-llama/Llama-3.3-70B-Instruct (groq/llama-3.3-70b-versatile)` +- `meta-llama/Llama-3.2-3B-Instruct (groq/llama-3.2-3b-preview)` ### Prerequisite: API Keys diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 3ba634e9a..95b11deee 100644 --- a/llama_stack/providers/registry/inference.py +++
b/llama_stack/providers/registry/inference.py @@ -157,16 +157,6 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", ), ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="groq", - pip_packages=["groq"], - module="llama_stack.providers.remote.inference.groq", - config_class="llama_stack.providers.remote.inference.groq.GroqConfig", - provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator", - ), - ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( @@ -214,6 +204,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=["litellm"], module="llama_stack.providers.remote.inference.openai", config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", ), ), remote_provider_spec( @@ -223,6 +214,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=["litellm"], module="llama_stack.providers.remote.inference.anthropic", config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", + provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", ), ), remote_provider_spec( @@ -232,6 +224,17 @@ def available_providers() -> List[ProviderSpec]: pip_packages=["litellm"], module="llama_stack.providers.remote.inference.gemini", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", + provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + ), + ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="groq", + pip_packages=["groq"], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", ), ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/anthropic/anthropic.py b/llama_stack/providers/remote/inference/anthropic/anthropic.py index 2b392b295..fa0a7e10f 100644 --- a/llama_stack/providers/remote/inference/anthropic/anthropic.py +++ b/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES class AnthropicInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: AnthropicConfig) -> None: - LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + api_key_from_config=config.api_key, + provider_data_api_key_field="anthropic_api_key", + ) self.config = config async def initialize(self) -> None: - pass + await super().initialize() async def shutdown(self) -> None: - pass + await super().shutdown() diff --git a/llama_stack/providers/remote/inference/anthropic/config.py b/llama_stack/providers/remote/inference/anthropic/config.py index 00323b1e7..0e9469602 100644 --- a/llama_stack/providers/remote/inference/anthropic/config.py +++ b/llama_stack/providers/remote/inference/anthropic/config.py @@ -11,6 +11,13 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class AnthropicProviderDataValidator(BaseModel): + anthropic_api_key: Optional[str] = Field( + default=None, + description="API key for Anthropic models", + ) + + @json_schema_type class 
AnthropicConfig(BaseModel): api_key: Optional[str] = Field( diff --git a/llama_stack/providers/remote/inference/gemini/config.py b/llama_stack/providers/remote/inference/gemini/config.py index cce8c756c..30c8d9913 100644 --- a/llama_stack/providers/remote/inference/gemini/config.py +++ b/llama_stack/providers/remote/inference/gemini/config.py @@ -11,6 +11,13 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class GeminiProviderDataValidator(BaseModel): + gemini_api_key: Optional[str] = Field( + default=None, + description="API key for Gemini models", + ) + + @json_schema_type class GeminiConfig(BaseModel): api_key: Optional[str] = Field( diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index b269bc14a..11f6f05ad 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES class GeminiInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: GeminiConfig) -> None: - LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + api_key_from_config=config.api_key, + provider_data_api_key_field="gemini_api_key", + ) self.config = config async def initialize(self) -> None: - pass + await super().initialize() async def shutdown(self) -> None: - pass + await super().shutdown() diff --git a/llama_stack/providers/remote/inference/groq/__init__.py b/llama_stack/providers/remote/inference/groq/__init__.py index 923c35696..1506e0b06 100644 --- a/llama_stack/providers/remote/inference/groq/__init__.py +++ b/llama_stack/providers/remote/inference/groq/__init__.py @@ -4,23 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from pydantic import BaseModel - from llama_stack.apis.inference import Inference from .config import GroqConfig -class GroqProviderDataValidator(BaseModel): - groq_api_key: str - - async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: # import dynamically so the import is used only when it is needed from .groq import GroqInferenceAdapter - if not isinstance(config, GroqConfig): - raise RuntimeError(f"Unexpected config type: {type(config)}") - adapter = GroqInferenceAdapter(config) return adapter diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py index 6b221478c..8a1204b0b 100644 --- a/llama_stack/providers/remote/inference/groq/config.py +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -11,6 +11,13 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class GroqProviderDataValidator(BaseModel): + groq_api_key: Optional[str] = Field( + default=None, + description="API key for Groq models", + ) + + @json_schema_type class GroqConfig(BaseModel): api_key: Optional[str] = Field( @@ -25,8 +32,8 @@ class GroqConfig(BaseModel): ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]: return { "url": "https://api.groq.com", - "api_key": "${env.GROQ_API_KEY}", + "api_key": api_key, } diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 2c9fab614..c8789434f 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -4,130 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import warnings -from typing import AsyncIterator, List, Optional, Union - -import groq -from groq import Groq - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, - CompletionResponse, - CompletionResponseStreamChunk, - EmbeddingsResponse, - EmbeddingTaskType, - Inference, - InterleavedContent, - InterleavedContentItem, - LogProbConfig, - Message, - ResponseFormat, - TextTruncation, - ToolChoice, - ToolConfig, -) -from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat from llama_stack.providers.remote.inference.groq.config import GroqConfig -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from .groq_utils import ( - convert_chat_completion_request, - convert_chat_completion_response, - convert_chat_completion_response_stream, -) -from .models import _MODEL_ENTRIES +from .models import MODEL_ENTRIES -class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData): +class GroqInferenceAdapter(LiteLLMOpenAIMixin): _config: GroqConfig def __init__(self, config: GroqConfig): - ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self._config = config - - def completion( - self, - model_id: str, - content: InterleavedContent, - sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: - # Groq doesn't support non-chat completion as of time of writing - raise NotImplementedError() - - async def chat_completion( - self, - model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: - model_id = self.get_provider_model_id(model_id) - if model_id == "llama-3.2-3b-preview": - warnings.warn( - "Groq only contains a preview version for llama-3.2-3b-instruct. " - "Preview models aren't recommended for production use. " - "They can be discontinued on short notice." 
- "More details: https://console.groq.com/docs/models" - ) - - request = convert_chat_completion_request( - request=ChatCompletionRequest( - model=model_id, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - tools=tools, - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) + LiteLLMOpenAIMixin.__init__( + self, + model_entries=MODEL_ENTRIES, + api_key_from_config=config.api_key, + provider_data_api_key_field="groq_api_key", ) + self.config = config - try: - response = self._get_client().chat.completions.create(**request) - except groq.BadRequestError as e: - if e.body.get("error", {}).get("code") == "tool_use_failed": - # For smaller models, Groq may fail to call a tool even when the request is well formed - raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e - else: - raise e + async def initialize(self): + await super().initialize() - if stream: - return convert_chat_completion_response_stream(response) - else: - return convert_chat_completion_response(response) - - async def embeddings( - self, - model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, - ) -> EmbeddingsResponse: - raise NotImplementedError() - - def _get_client(self) -> Groq: - if self._config.api_key is not None: - return Groq(api_key=self._config.api_key) - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.groq_api_key: - raise ValueError( - 'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "" }' - ) - return Groq(api_key=provider_data.groq_api_key) + async def shutdown(self): + await super().shutdown() diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py deleted file mode 100644 index f1138e789..000000000 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import json -import warnings -from typing import AsyncGenerator, Literal - -from groq import Stream -from groq.types.chat.chat_completion import ChatCompletion -from groq.types.chat.chat_completion_assistant_message_param import ( - ChatCompletionAssistantMessageParam, -) -from groq.types.chat.chat_completion_chunk import ChatCompletionChunk -from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam -from groq.types.chat.chat_completion_system_message_param import ( - ChatCompletionSystemMessageParam, -) -from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam -from groq.types.chat.chat_completion_user_message_param import ( - ChatCompletionUserMessageParam, -) -from groq.types.chat.completion_create_params import CompletionCreateParams -from groq.types.shared.function_definition import FunctionDefinition - -from llama_stack.apis.common.content_types import ( - TextDelta, - ToolCallDelta, - ToolCallParseStatus, -) -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - Message, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.models.llama.datatypes import ToolParamDefinition -from llama_stack.providers.utils.inference.openai_compat import ( - UnparseableToolCall, - convert_tool_call, - get_sampling_strategy_options, -) - - -def convert_chat_completion_request( - request: ChatCompletionRequest, -) -> CompletionCreateParams: - """ - Convert a ChatCompletionRequest to a Groq API-compatible dictionary. - Warns client if request contains unsupported features. - """ - - if request.logprobs: - # Groq doesn't support logprobs at the time of writing - warnings.warn("logprobs are not supported yet") - - if request.response_format: - # Groq's JSON mode is beta at the time of writing - warnings.warn("response_format is not supported yet") - - if request.sampling_params.repetition_penalty != 1.0: - # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty - # seem to have different semantics - # frequency_penalty defaults to 0 is a float between -2.0 and 2.0 - # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0 - # so we exclude it for now - warnings.warn("repetition_penalty is not supported") - - if request.tool_config.tool_prompt_format != ToolPromptFormat.json: - warnings.warn("tool_prompt_format is not used by Groq. 
Ignoring.") - - sampling_options = get_sampling_strategy_options(request.sampling_params) - return CompletionCreateParams( - model=request.model, - messages=[_convert_message(message) for message in request.messages], - logprobs=None, - frequency_penalty=None, - stream=request.stream, - max_tokens=request.sampling_params.max_tokens or None, - temperature=sampling_options.get("temperature", 1.0), - top_p=sampling_options.get("top_p", 1.0), - tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []], - tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None), - ) - - -def _convert_message(message: Message) -> ChatCompletionMessageParam: - if message.role == "system": - return ChatCompletionSystemMessageParam(role="system", content=message.content) - elif message.role == "user": - return ChatCompletionUserMessageParam(role="user", content=message.content) - elif message.role == "assistant": - return ChatCompletionAssistantMessageParam(role="assistant", content=message.content) - else: - raise ValueError(f"Invalid message role: {message.role}") - - -def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict: - # Groq requires a description for function tools - if tool_definition.description is None: - raise AssertionError("tool_definition.description is required") - - tool_parameters = tool_definition.parameters or {} - return ChatCompletionToolParam( - type="function", - function=FunctionDefinition( - name=tool_definition.tool_name, - description=tool_definition.description, - parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()}, - ), - ) - - -def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict: - param = { - "type": tool_parameter.param_type, - } - if tool_parameter.description is not None: - param["description"] = tool_parameter.description - if tool_parameter.required is not None: - param["required"] = tool_parameter.required - if tool_parameter.default is not None: - param["default"] = tool_parameter.default - return param - - -def convert_chat_completion_response( - response: ChatCompletion, -) -> ChatCompletionResponse: - # groq only supports n=1 at time of writing, so there is only one choice - choice = response.choices[0] - if choice.finish_reason == "tool_calls": - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] - if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): - # If we couldn't parse a tool call, jsonify the tool calls and return them - return ChatCompletionResponse( - completion_message=CompletionMessage( - stop_reason=StopReason.end_of_message, - content=json.dumps(tool_calls, default=lambda x: x.model_dump()), - ), - logprobs=None, - ) - else: - # Otherwise, return tool calls as normal - return ChatCompletionResponse( - completion_message=CompletionMessage( - tool_calls=tool_calls, - stop_reason=StopReason.end_of_message, - # Content is not optional - content="", - ), - logprobs=None, - ) - else: - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content, - stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), - ), - ) - - -def _map_finish_reason_to_stop_reason( - finish_reason: Literal["stop", "length", "tool_calls"], -) -> StopReason: - """ - Convert a Groq chat completion finish_reason to a StopReason. 
- - finish_reason: Literal["stop", "length", "tool_calls"] - - stop -> model hit a natural stop point or a provided stop sequence - - length -> maximum number of tokens specified in the request was reached - - tool_calls -> model called a tool - """ - if finish_reason == "stop": - return StopReason.end_of_turn - elif finish_reason == "length": - return StopReason.out_of_tokens - elif finish_reason == "tool_calls": - return StopReason.end_of_message - else: - raise ValueError(f"Invalid finish reason: {finish_reason}") - - -async def convert_chat_completion_response_stream( - stream: Stream[ChatCompletionChunk], -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - event_type = ChatCompletionResponseEventType.start - for chunk in stream: - choice = chunk.choices[0] - - if choice.finish_reason: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=None, - stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), - ) - ) - elif choice.delta.tool_calls: - # We assume there is only one tool call per chunk, but emit a warning in case we're wrong - if len(choice.delta.tool_calls) > 1: - warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.") - - # We assume Groq produces fully formed tool calls for each chunk - tool_call = convert_tool_call(choice.delta.tool_calls[0]) - if isinstance(tool_call, ToolCall): - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - ) - ) - else: - # Otherwise it's an UnparseableToolCall - return the raw tool call - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=tool_call.model_dump_json(), - parse_status=ToolCallParseStatus.failed, - ), - ) - ) - else: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=None, - ) - ) - event_type = ChatCompletionResponseEventType.progress diff --git a/llama_stack/providers/remote/inference/groq/models.py b/llama_stack/providers/remote/inference/groq/models.py index 54ca2e839..4364edffa 100644 --- a/llama_stack/providers/remote/inference/groq/models.py +++ b/llama_stack/providers/remote/inference/groq/models.py @@ -7,21 +7,21 @@ from llama_stack.models.llama.sku_list import CoreModelId from llama_stack.providers.utils.inference.model_registry import build_model_entry -_MODEL_ENTRIES = [ +MODEL_ENTRIES = [ build_model_entry( - "llama3-8b-8192", + "groq/llama3-8b-8192", CoreModelId.llama3_1_8b_instruct.value, ), build_model_entry( - "llama-3.1-8b-instant", + "groq/llama-3.1-8b-instant", CoreModelId.llama3_1_8b_instruct.value, ), build_model_entry( - "llama3-70b-8192", + "groq/llama3-70b-8192", CoreModelId.llama3_70b_instruct.value, ), build_model_entry( - "llama-3.3-70b-versatile", + "groq/llama-3.3-70b-versatile", CoreModelId.llama3_3_70b_instruct.value, ), # Groq only contains a preview version for llama-3.2-3b @@ -29,7 +29,7 @@ _MODEL_ENTRIES = [ # to pass the test fixture # TODO(aidand): Replace this with a stable model once Groq supports it build_model_entry( - "llama-3.2-3b-preview", + "groq/llama-3.2-3b-preview", CoreModelId.llama3_2_3b_instruct.value, ), ] diff --git 
a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py index 07f96a3df..2b0cc2c10 100644 --- a/llama_stack/providers/remote/inference/openai/config.py +++ b/llama_stack/providers/remote/inference/openai/config.py @@ -11,6 +11,13 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class OpenAIProviderDataValidator(BaseModel): + openai_api_key: Optional[str] = Field( + default=None, + description="API key for OpenAI models", + ) + + @json_schema_type class OpenAIConfig(BaseModel): api_key: Optional[str] = Field( diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 80ab2943f..6b9c02e6c 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: OpenAIConfig) -> None: - LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + api_key_from_config=config.api_key, + provider_data_api_key_field="openai_api_key", + ) self.config = config async def initialize(self) -> None: - pass + await super().initialize() async def shutdown(self) -> None: - pass + await super().shutdown() diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py deleted file mode 100644 index 34725e957..000000000 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ /dev/null @@ -1,575 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import json - -import pytest -from groq.types.chat.chat_completion import ChatCompletion, Choice -from groq.types.chat.chat_completion_chunk import ( - ChatCompletionChunk, - ChoiceDelta, - ChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction, -) -from groq.types.chat.chat_completion_chunk import ( - Choice as StreamChoice, -) -from groq.types.chat.chat_completion_message import ChatCompletionMessage -from groq.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, - Function, -) -from groq.types.shared.function_definition import FunctionDefinition - -from llama_stack.apis.common.content_types import ToolCallParseStatus -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponseEventType, - CompletionMessage, - StopReason, - SystemMessage, - ToolCall, - ToolChoice, - ToolDefinition, - UserMessage, -) -from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy -from llama_stack.providers.remote.inference.groq.groq_utils import ( - convert_chat_completion_request, - convert_chat_completion_response, - convert_chat_completion_response_stream, -) - - -class TestConvertChatCompletionRequest: - def test_sets_model(self): - request = self._dummy_chat_completion_request() - request.model = "Llama-3.2-3B" - - converted = convert_chat_completion_request(request) - - assert converted["model"] == "Llama-3.2-3B" - - def test_converts_user_message(self): - request = self._dummy_chat_completion_request() - request.messages = [UserMessage(content="Hello World")] - - converted = convert_chat_completion_request(request) - - assert converted["messages"] == [ - {"role": "user", "content": "Hello World"}, - ] - - def test_converts_system_message(self): - request = self._dummy_chat_completion_request() - request.messages = [SystemMessage(content="You are a helpful assistant.")] - - converted = convert_chat_completion_request(request) - - assert converted["messages"] == [ - {"role": "system", "content": "You are a helpful assistant."}, - ] - - def test_converts_completion_message(self): - request = self._dummy_chat_completion_request() - request.messages = [ - UserMessage(content="Hello World"), - CompletionMessage( - content="Hello World! How can I help you today?", - stop_reason=StopReason.end_of_message, - ), - ] - - converted = convert_chat_completion_request(request) - - assert converted["messages"] == [ - {"role": "user", "content": "Hello World"}, - {"role": "assistant", "content": "Hello World! 
How can I help you today?"}, - ] - - def test_does_not_include_logprobs(self): - request = self._dummy_chat_completion_request() - request.logprobs = True - - with pytest.warns(Warning) as warnings: - converted = convert_chat_completion_request(request) - - assert "logprobs are not supported yet" in warnings[0].message.args[0] - assert converted.get("logprobs") is None - - def test_does_not_include_response_format(self): - request = self._dummy_chat_completion_request() - request.response_format = { - "type": "json_object", - "json_schema": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - }, - }, - } - - with pytest.warns(Warning) as warnings: - converted = convert_chat_completion_request(request) - - assert "response_format is not supported yet" in warnings[0].message.args[0] - assert converted.get("response_format") is None - - def test_does_not_include_repetition_penalty(self): - request = self._dummy_chat_completion_request() - request.sampling_params.repetition_penalty = 1.5 - - with pytest.warns(Warning) as warnings: - converted = convert_chat_completion_request(request) - - assert "repetition_penalty is not supported" in warnings[0].message.args[0] - assert converted.get("repetition_penalty") is None - assert converted.get("frequency_penalty") is None - - def test_includes_stream(self): - request = self._dummy_chat_completion_request() - request.stream = True - - converted = convert_chat_completion_request(request) - - assert converted["stream"] is True - - def test_if_max_tokens_is_0_then_it_is_not_included(self): - request = self._dummy_chat_completion_request() - # 0 is the default value for max_tokens - # So we assume that if it's 0, the user didn't set it - request.sampling_params.max_tokens = 0 - - converted = convert_chat_completion_request(request) - - assert converted.get("max_tokens") is None - - def test_includes_max_tokens_if_set(self): - request = self._dummy_chat_completion_request() - request.sampling_params.max_tokens = 100 - - converted = convert_chat_completion_request(request) - - assert converted["max_tokens"] == 100 - - def _dummy_chat_completion_request(self): - return ChatCompletionRequest( - model="Llama-3.2-3B", - messages=[UserMessage(content="Hello World")], - ) - - def test_includes_stratgy(self): - request = self._dummy_chat_completion_request() - request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95) - - converted = convert_chat_completion_request(request) - - assert converted["temperature"] == 0.5 - assert converted["top_p"] == 0.95 - - def test_includes_greedy_strategy(self): - request = self._dummy_chat_completion_request() - request.sampling_params.strategy = GreedySamplingStrategy() - - converted = convert_chat_completion_request(request) - - assert converted["temperature"] == 0.0 - - def test_includes_tool_choice(self): - request = self._dummy_chat_completion_request() - request.tool_config.tool_choice = ToolChoice.required - - converted = convert_chat_completion_request(request) - - assert converted["tool_choice"] == "required" - - def test_includes_tools(self): - request = self._dummy_chat_completion_request() - request.tools = [ - ToolDefinition( - tool_name="get_flight_info", - description="Get fight information between two destinations.", - parameters={ - "origin": ToolParamDefinition( - param_type="string", - description="The origin airport code. 
E.g., AU", - required=True, - ), - "destination": ToolParamDefinition( - param_type="string", - description="The destination airport code. E.g., 'LAX'", - required=True, - ), - "passengers": ToolParamDefinition( - param_type="array", - description="The passengers", - required=False, - ), - }, - ), - ToolDefinition( - tool_name="log", - description="Calulate the logarithm of a number", - parameters={ - "number": ToolParamDefinition( - param_type="float", - description="The number to calculate the logarithm of", - required=True, - ), - "base": ToolParamDefinition( - param_type="integer", - description="The base of the logarithm", - required=False, - default=10, - ), - }, - ), - ] - - converted = convert_chat_completion_request(request) - - assert converted["tools"] == [ - { - "type": "function", - "function": FunctionDefinition( - name="get_flight_info", - description="Get fight information between two destinations.", - parameters={ - "origin": { - "type": "string", - "description": "The origin airport code. E.g., AU", - "required": True, - }, - "destination": { - "type": "string", - "description": "The destination airport code. E.g., 'LAX'", - "required": True, - }, - "passengers": { - "type": "array", - "description": "The passengers", - "required": False, - }, - }, - ), - }, - { - "type": "function", - "function": FunctionDefinition( - name="log", - description="Calulate the logarithm of a number", - parameters={ - "number": { - "type": "float", - "description": "The number to calculate the logarithm of", - "required": True, - }, - "base": { - "type": "integer", - "description": "The base of the logarithm", - "required": False, - "default": 10, - }, - }, - ), - }, - ] - - -class TestConvertNonStreamChatCompletionResponse: - def test_returns_response(self): - response = self._dummy_chat_completion_response() - response.choices[0].message.content = "Hello World" - - converted = convert_chat_completion_response(response) - - assert converted.completion_message.content == "Hello World" - - def test_maps_stop_to_end_of_message(self): - response = self._dummy_chat_completion_response() - response.choices[0].finish_reason = "stop" - - converted = convert_chat_completion_response(response) - - assert converted.completion_message.stop_reason == StopReason.end_of_turn - - def test_maps_length_to_end_of_message(self): - response = self._dummy_chat_completion_response() - response.choices[0].finish_reason = "length" - - converted = convert_chat_completion_response(response) - - assert converted.completion_message.stop_reason == StopReason.out_of_tokens - - def test_maps_tool_call_to_end_of_message(self): - response = self._dummy_chat_completion_response_with_tool_call() - - converted = convert_chat_completion_response(response) - - assert converted.completion_message.stop_reason == StopReason.end_of_message - - def test_converts_multiple_tool_calls(self): - response = self._dummy_chat_completion_response_with_tool_call() - response.choices[0].message.tool_calls = [ - ChatCompletionMessageToolCall( - id="tool_call_id", - type="function", - function=Function( - name="get_flight_info", - arguments='{"origin": "AU", "destination": "LAX"}', - ), - ), - ChatCompletionMessageToolCall( - id="tool_call_id_2", - type="function", - function=Function( - name="log", - arguments='{"number": 10, "base": 2}', - ), - ), - ] - - converted = convert_chat_completion_response(response) - - assert converted.completion_message.tool_calls == [ - ToolCall( - call_id="tool_call_id", - tool_name="get_flight_info", - 
arguments={"origin": "AU", "destination": "LAX"}, - ), - ToolCall( - call_id="tool_call_id_2", - tool_name="log", - arguments={"number": 10, "base": 2}, - ), - ] - - def test_converts_unparseable_tool_calls(self): - response = self._dummy_chat_completion_response_with_tool_call() - response.choices[0].message.tool_calls = [ - ChatCompletionMessageToolCall( - id="tool_call_id", - type="function", - function=Function( - name="log", - arguments="(number=10, base=2)", - ), - ), - ] - - converted = convert_chat_completion_response(response) - - assert ( - converted.completion_message.content - == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]' - ) - - def _dummy_chat_completion_response(self): - return ChatCompletion( - id="chatcmpl-123", - model="Llama-3.2-3B", - choices=[ - Choice( - index=0, - message=ChatCompletionMessage(role="assistant", content="Hello World"), - finish_reason="stop", - ) - ], - created=1729382400, - object="chat.completion", - ) - - def _dummy_chat_completion_response_with_tool_call(self): - return ChatCompletion( - id="chatcmpl-123", - model="Llama-3.2-3B", - choices=[ - Choice( - index=0, - message=ChatCompletionMessage( - role="assistant", - tool_calls=[ - ChatCompletionMessageToolCall( - id="tool_call_id", - type="function", - function=Function( - name="get_flight_info", - arguments='{"origin": "AU", "destination": "LAX"}', - ), - ) - ], - ), - finish_reason="tool_calls", - ) - ], - created=1729382400, - object="chat.completion", - ) - - -class TestConvertStreamChatCompletionResponse: - @pytest.mark.asyncio - async def test_returns_stream(self): - def chat_completion_stream(): - messages = ["Hello ", "World ", " !"] - for i, message in enumerate(messages): - chunk = self._dummy_chat_completion_chunk() - chunk.choices[0].delta.content = message - yield chunk - - chunk = self._dummy_chat_completion_chunk() - chunk.choices[0].delta.content = None - chunk.choices[0].finish_reason = "stop" - yield chunk - - stream = chat_completion_stream() - converted = convert_chat_completion_response_stream(stream) - - iter = converted.__aiter__() - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.start - assert chunk.event.delta.text == "Hello " - - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.progress - assert chunk.event.delta.text == "World " - - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.progress - assert chunk.event.delta.text == " !" 
- - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.complete - assert chunk.event.delta.text == "" - assert chunk.event.stop_reason == StopReason.end_of_turn - - with pytest.raises(StopAsyncIteration): - await iter.__anext__() - - @pytest.mark.asyncio - async def test_returns_tool_calls_stream(self): - def tool_call_stream(): - tool_calls = [ - ToolCall( - call_id="tool_call_id", - tool_name="get_flight_info", - arguments={"origin": "AU", "destination": "LAX"}, - ), - ToolCall( - call_id="tool_call_id_2", - tool_name="log", - arguments={"number": 10, "base": 2}, - ), - ] - for i, tool_call in enumerate(tool_calls): - chunk = self._dummy_chat_completion_chunk_with_tool_call() - chunk.choices[0].delta.tool_calls = [ - ChoiceDeltaToolCall( - index=0, - type="function", - id=tool_call.call_id, - function=ChoiceDeltaToolCallFunction( - name=tool_call.tool_name, - arguments=json.dumps(tool_call.arguments), - ), - ), - ] - yield chunk - - chunk = self._dummy_chat_completion_chunk_with_tool_call() - chunk.choices[0].delta.content = None - chunk.choices[0].finish_reason = "stop" - yield chunk - - stream = tool_call_stream() - converted = convert_chat_completion_response_stream(stream) - - iter = converted.__aiter__() - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.start - assert chunk.event.delta.tool_call == ToolCall( - call_id="tool_call_id", - tool_name="get_flight_info", - arguments={"origin": "AU", "destination": "LAX"}, - ) - - @pytest.mark.asyncio - async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self): - def tool_call_stream(): - chunk = self._dummy_chat_completion_chunk_with_tool_call() - chunk.choices[0].delta.tool_calls = [ - ChoiceDeltaToolCall( - index=0, - type="function", - id="tool_call_id", - function=ChoiceDeltaToolCallFunction( - name="get_flight_info", - arguments="(origin=AU, destination=LAX)", - ), - ), - ] - yield chunk - - chunk = self._dummy_chat_completion_chunk_with_tool_call() - chunk.choices[0].delta.content = None - chunk.choices[0].finish_reason = "stop" - yield chunk - - stream = tool_call_stream() - converted = convert_chat_completion_response_stream(stream) - - iter = converted.__aiter__() - chunk = await iter.__anext__() - assert chunk.event.event_type == ChatCompletionResponseEventType.start - assert ( - chunk.event.delta.content - == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}' - ) - assert chunk.event.delta.parse_status == ToolCallParseStatus.failed - - def _dummy_chat_completion_chunk(self): - return ChatCompletionChunk( - id="chatcmpl-123", - model="Llama-3.2-3B", - choices=[ - StreamChoice( - index=0, - delta=ChoiceDelta(role="assistant", content="Hello World"), - ) - ], - created=1729382400, - object="chat.completion.chunk", - x_groq=None, - ) - - def _dummy_chat_completion_chunk_with_tool_call(self): - return ChatCompletionChunk( - id="chatcmpl-123", - model="Llama-3.2-3B", - choices=[ - StreamChoice( - index=0, - delta=ChoiceDelta( - role="assistant", - content="Hello World", - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - type="function", - function=ChoiceDeltaToolCallFunction( - name="get_flight_info", - arguments='{"origin": "AU", "destination": "LAX"}', - ), - ) - ], - ), - ) - ], - created=1729382400, - object="chat.completion.chunk", - x_groq=None, - ) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py 
b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index a916e4f99..ecb6961da 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -31,6 +31,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models.models import Model +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -49,10 +50,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( class LiteLLMOpenAIMixin( ModelRegistryHelper, Inference, + NeedsRequestProviderData, ): - def __init__(self, model_entries) -> None: - self.model_entries = model_entries + def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str): ModelRegistryHelper.__init__(self, model_entries) + self.api_key_from_config = api_key_from_config + self.provider_data_api_key_field = provider_data_api_key_field + + async def initialize(self): + pass + + async def shutdown(self): + pass async def register_model(self, model: Model) -> Model: model_id = self.get_provider_model_id(model.provider_resource_id) @@ -144,8 +153,16 @@ class LiteLLMOpenAIMixin( if request.tool_config.tool_choice: input_dict["tool_choice"] = request.tool_config.tool_choice.value + provider_data = self.get_request_provider_data() + key_field = self.provider_data_api_key_field + if provider_data and getattr(provider_data, key_field, None): + api_key = getattr(provider_data, key_field) + else: + api_key = self.api_key_from_config + return { "model": request.model, + "api_key": api_key, **input_dict, "stream": request.stream, **get_sampling_options(request.sampling_params), diff --git a/llama_stack/templates/dev/build.yaml b/llama_stack/templates/dev/build.yaml index 96f588e8d..726ebccca 100644 --- a/llama_stack/templates/dev/build.yaml +++ b/llama_stack/templates/dev/build.yaml @@ -7,6 +7,7 @@ distribution_spec: - remote::fireworks - remote::anthropic - remote::gemini + - remote::groq - inline::sentence-transformers vector_io: - inline::sqlite-vec diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py index 7b449a0b4..fe80c3842 100644 --- a/llama_stack/templates/dev/dev.py +++ b/llama_stack/templates/dev/dev.py @@ -24,6 +24,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES from llama_stack.providers.remote.inference.gemini.config import GeminiConfig from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -52,6 +54,11 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: GEMINI_MODEL_ENTRIES, GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"), ), + ( + "groq", + GROQ_MODEL_ENTRIES, + GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"), + ), ] inference_providers = [] default_models = [] @@ -78,14 +85,9 @@ 
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: def get_distribution_template() -> DistributionTemplate: + inference_providers, default_models = get_inference_providers() providers = { - "inference": [ - "remote::openai", - "remote::fireworks", - "remote::anthropic", - "remote::gemini", - "inline::sentence-transformers", - ], + "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]), "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], @@ -136,7 +138,6 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) - inference_providers, default_models = get_inference_providers() return DistributionTemplate( name=name, diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml index 448a3aec7..0ada465e4 100644 --- a/llama_stack/templates/dev/run.yaml +++ b/llama_stack/templates/dev/run.yaml @@ -29,6 +29,11 @@ providers: provider_type: remote::gemini config: api_key: ${env.GEMINI_API_KEY:} + - provider_id: groq + provider_type: remote::groq + config: + url: https://api.groq.com + api_key: ${env.GROQ_API_KEY:} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -241,6 +246,31 @@ models: provider_id: gemini provider_model_id: gemini/text-embedding-004 model_type: embedding +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: groq + provider_model_id: groq/llama3-8b-8192 + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: groq + provider_model_id: groq/llama-3.1-8b-instant + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3-70B-Instruct + provider_id: groq + provider_model_id: groq/llama3-70b-8192 + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: groq + provider_model_id: groq/llama-3.3-70b-versatile + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: groq + provider_model_id: groq/llama-3.2-3b-preview + model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 diff --git a/llama_stack/templates/groq/groq.py b/llama_stack/templates/groq/groq.py index 9e25f02cb..b0c7a3804 100644 --- a/llama_stack/templates/groq/groq.py +++ b/llama_stack/templates/groq/groq.py @@ -16,9 +16,8 @@ from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.groq import GroqConfig -from llama_stack.providers.remote.inference.groq.models import _MODEL_ENTRIES +from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -52,11 +51,6 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::sentence-transformers", config=SentenceTransformersInferenceConfig.sample_run_config(), ) - vector_io_provider = Provider( - provider_id="faiss", - provider_type="inline::faiss", - config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), - ) embedding_model = ModelInput( model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", @@ -69,11 +63,13 @@ def get_distribution_template() -> 
DistributionTemplate: core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} default_models = [ ModelInput( - model_id=core_model_to_hf_repo[m.llama_model], + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, provider_model_id=m.provider_model_id, provider_id=name, + model_type=m.model_type, + metadata=m.metadata, ) - for m in _MODEL_ENTRIES + for m in MODEL_ENTRIES ] default_tool_groups = [ diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 218514cf6..220aa847b 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -93,27 +93,27 @@ models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: groq - provider_model_id: llama3-8b-8192 + provider_model_id: groq/llama3-8b-8192 model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: groq - provider_model_id: llama-3.1-8b-instant + provider_model_id: groq/llama-3.1-8b-instant model_type: llm - metadata: {} model_id: meta-llama/Llama-3-70B-Instruct provider_id: groq - provider_model_id: llama3-70b-8192 + provider_model_id: groq/llama3-70b-8192 model_type: llm - metadata: {} model_id: meta-llama/Llama-3.3-70B-Instruct provider_id: groq - provider_model_id: llama-3.3-70b-versatile + provider_model_id: groq/llama-3.3-70b-versatile model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct provider_id: groq - provider_model_id: llama-3.2-3b-preview + provider_model_id: groq/llama-3.2-3b-preview model_type: llm - metadata: embedding_dimension: 384 diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 961194a73..c0f4dca53 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -117,7 +117,9 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed assert len(providers) > 0, "No inference providers found" inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] - model_ids = [m.identifier for m in client.models.list()] + model_ids = set(m.identifier for m in client.models.list()) + model_ids.update(m.provider_resource_id for m in client.models.list()) + if text_model_id and text_model_id not in model_ids: client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) if vision_model_id and vision_model_id not in model_ids: diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 59b5bf12a..577d995ad 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -18,7 +18,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} provider = providers[provider_id] - if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"): + if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
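
With the `provider_data` plumbing above, a per-request API key can be supplied through the `X-LlamaStack-Provider-Data` header instead of the provider's run config. A minimal sketch of such a client call, assuming a locally running stack: the header name and the `groq_api_key` field come from the adapter, while the base URL, port, endpoint path, and key value are illustrative assumptions.

```python
import json

import httpx

# Sketch: pass a per-request Groq API key via the X-LlamaStack-Provider-Data header.
# The header name and the {"groq_api_key": ...} shape match what the Groq adapter
# expects; the base URL, endpoint path, and key value below are assumptions.
provider_data = {"groq_api_key": "gsk_..."}

response = httpx.post(
    "http://localhost:8321/v1/inference/chat-completion",  # assumed local stack endpoint
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    json={
        "model_id": "groq/llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=60,
)
print(response.json())
```

The same pattern should apply to the other providers through their respective fields (`openai_api_key`, `anthropic_api_key`, `gemini_api_key`), since they all share the `LiteLLMOpenAIMixin` key-resolution logic.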