diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index c8d061f6c..f159c807e 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -161,4 +161,16 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="groq",
+                pip_packages=[
+                    "openai",
+                ],
+                module="llama_stack.providers.remote.inference.groq",
+                config_class="llama_stack.providers.remote.inference.groq.GroqImplConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
+            ),
+        ),
     ]
diff --git a/llama_stack/providers/remote/inference/groq/__init__.py b/llama_stack/providers/remote/inference/groq/__init__.py
new file mode 100644
index 000000000..1a0120983
--- /dev/null
+++ b/llama_stack/providers/remote/inference/groq/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from .config import GroqImplConfig
+
+
+class GroqProviderDataValidator(BaseModel):
+    groq_api_key: str
+
+
+async def get_adapter_impl(config: GroqImplConfig, _deps):
+    from .groq import GroqInferenceAdapter
+
+    assert isinstance(
+        config, GroqImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = GroqInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py
new file mode 100644
index 000000000..12b4fc330
--- /dev/null
+++ b/llama_stack/providers/remote/inference/groq/config.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class GroqImplConfig(BaseModel):
+    url: str = Field(
+        default="https://api.groq.com/openai/v1",
+        description="The URL for the Groq API server",
+    )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="The Groq API Key",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "url": "https://api.groq.com/openai/v1",
+            "api_key": "${env.GROQ_API_KEY}",
+        }
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
new file mode 100644
index 000000000..56f457af8
--- /dev/null
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -0,0 +1,464 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import AsyncGenerator, Dict, List, Optional, Union
+
+from llama_models.datatypes import CoreModelId
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+from openai import AsyncOpenAI
+
+from llama_stack.apis.inference import (
+    AsyncIterator,
+    ChatCompletionRequest,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    InterleavedTextMedia,
+    LogProbConfig,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelRegistryHelper,
+    build_model_alias,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    ChatCompletionResponseStreamChunk,
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+
+from .config import GroqImplConfig
+
+
+class GroqErrorCode(str, Enum):
+    INVALID_AUTH = "invalid_authentication"
+    RATE_LIMIT_EXCEEDED = "rate_limit_exceeded"
+    QUOTA_EXCEEDED = "quota_exceeded"
+    CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded"
+    INVALID_REQUEST = "invalid_request"
+    MODEL_NOT_FOUND = "model_not_found"
+
+
+MODEL_ALIASES = [
+    build_model_alias(
+        "llama-3.2-1b-preview",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.2-3b-preview",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.2-11b-vision-preview",
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.2-90b-vision-preview",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.1-8b-instant",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.1-70b-versatile",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_model_alias(
+        "llama3-8b-8192",
+        CoreModelId.llama3_8b_instruct.value,
+    ),
+    build_model_alias(
+        "llama3-70b-8192",
+        CoreModelId.llama3_70b_instruct.value,
+    ),
+    build_model_alias(
+        "llama3-groq-8b-8192-tool-use-preview",
+        CoreModelId.llama3_8b_instruct.value,
+    ),
+    build_model_alias(
+        "llama3-groq-70b-8192-tool-use-preview",
+        CoreModelId.llama3_70b_instruct.value,
+    ),
+    build_model_alias(
+        "llama-guard-3-8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+]
+
+UNSUPPORTED_PARAMS = {
+    "logprobs",
+    "top_logprobs",
+    "response_format",
+}
+
+
+class GroqInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
+    """Groq inference adapter using OpenAI client compatibility layer.
+
+    This adapter provides access to Groq's AI models through their
+    OpenAI-compatible API. It handles authentication, request formatting, and
+    response processing while managing unsupported features gracefully.
+
+    Supports tool/function calling for compatible models.
+    """
+
+    def __init__(self, config: GroqImplConfig) -> None:
+        """Initialize the Groq inference adapter.
+
+        Args:
+            config: Configuration for the Groq implementation
+        """
+        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+        self._client: Optional[AsyncOpenAI] = None
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        """Get or create the OpenAI client instance.
+
+        Returns:
+            AsyncOpenAI: The configured client instance
+        """
+        if self._client is None:
+            self._client = AsyncOpenAI(
+                base_url=self.config.url,
+                api_key=self._get_api_key(),
+                timeout=60.0,
+            )
+        return self._client
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    def _get_api_key(self) -> str:
+        """Get the API key from config or request headers.
+
+        Returns:
+            str: The API key to use
+
+        Raises:
+            ValueError: If no API key is available
+        """
+        if self.config.api_key is not None:
+            return self.config.api_key
+
+        provider_data = self.get_request_provider_data()
+        if provider_data is None or not provider_data.groq_api_key:
+            raise ValueError(
+                'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": <your api key> }'
+            )
+        return provider_data.groq_api_key
+
+    def _filter_unsupported_params(self, params: Dict) -> Dict:
+        """Remove parameters not supported by the Groq API.
+
+        Args:
+            params: Original parameters dictionary
+
+        Returns:
+            Dict: Filtered parameters dictionary
+        """
+        return {k: v for k, v in params.items() if k not in UNSUPPORTED_PARAMS}
+
+    def _convert_tool_to_function(self, tool: ToolDefinition) -> dict:
+        """Convert a ToolDefinition to Groq function format.
+
+        Args:
+            tool: Tool definition to convert
+
+        Returns:
+            dict: Function definition in Groq format
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.tool_name,
+                "description": tool.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        name: {
+                            "type": param.param_type,
+                            "description": param.description,
+                        }
+                        for name, param in tool.parameters.items()
+                    },
+                    "required": list(tool.parameters.keys()),
+                },
+            },
+        }
+
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        """Prepare parameters for the API request.
+
+        Args:
+            request: The completion request
+
+        Returns:
+            dict: Prepared parameters for the API call
+        """
+        sampling_options = get_sampling_options(request.sampling_params)
+        filtered_options = self._filter_unsupported_params(sampling_options)
+
+        if "temperature" in filtered_options:
+            filtered_options["temperature"] = min(
+                max(filtered_options["temperature"], 0), 2
+            )
+
+        input_dict = {}
+        if isinstance(request, ChatCompletionRequest):
+            input_dict["messages"] = [
+                {"role": message.role, "content": message.content}
+                for message in request.messages
+            ]
+
+            if request.tools:
+                input_dict["tools"] = [
+                    self._convert_tool_to_function(tool)
+                    for tool in request.tools
+                ]
+
+                if request.tool_choice == ToolChoice.auto:
+                    input_dict["tool_choice"] = "auto"
+                elif request.tool_choice == ToolChoice.required:
+                    input_dict["tool_choice"] = "required"
+                elif isinstance(request.tool_choice, str):
+                    input_dict["tool_choice"] = {
+                        "type": "function",
+                        "function": {"name": request.tool_choice},
+                    }
+                else:
+                    input_dict["tool_choice"] = "none"
+
+        else:
+            input_dict["prompt"] = request.content
+
+        return {
+            "model": request.model,
+            **input_dict,
+            **filtered_options,
+            "stream": request.stream,
+        }
+
+    async def _handle_groq_error(self, e: Exception) -> None:
+        """Handle Groq-specific API errors with detailed messages.
+
+        Args:
+            e: The exception to handle
+
+        Raises:
+            ValueError: For client errors
+            RuntimeError: For server errors
+        """
+        error_msg = str(e)
+        error_data = {}
+
+        try:
+            if hasattr(e, "response"):
+                error_data = e.response.json().get("error", {})
+        except Exception:
+            pass
+
+        error_code = error_data.get("code", "")
+        error_message = error_data.get("message", error_msg)
+
+        if "401" in error_msg or error_code == GroqErrorCode.INVALID_AUTH:
+            raise ValueError("Invalid API key or unauthorized access") from e
+
+        elif (
+            "429" in error_msg
+            or error_code == GroqErrorCode.RATE_LIMIT_EXCEEDED
+        ):
+            raise ValueError(
+                "Rate limit exceeded. Please try again later"
+            ) from e
+
+        elif error_code == GroqErrorCode.QUOTA_EXCEEDED:
+            raise ValueError(
+                "API quota exceeded. Please check your usage limits"
+            ) from e
+
+        elif error_code == GroqErrorCode.CONTEXT_LENGTH_EXCEEDED:
+            raise ValueError(
+                "Total number of input and output tokens exceeds model's context length"
+            ) from e
+
+        elif error_code == GroqErrorCode.INVALID_REQUEST:
+            raise ValueError(f"Invalid request: {error_message}") from e
+
+        elif error_code == GroqErrorCode.MODEL_NOT_FOUND:
+            raise ValueError(f"Model not found: {error_message}") from e
+
+        raise RuntimeError(f"Groq API error: {error_message}") from e
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        try:
+            params = await self._get_params(request)
+            response = await self.client.chat.completions.create(**params)
+
+            if response.choices[0].message.tool_calls:
+                tool_call = response.choices[0].message.tool_calls[0]
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=response.choices[0].finish_reason,
+                    text="",
+                    tool_calls=[
+                        {
+                            "tool_name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments or "",
+                        }
+                    ],
+                )
+            else:
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=response.choices[0].finish_reason,
+                    text=response.choices[0].message.content or "",
+                    tool_calls=[],
+                )
+
+            compat_response = OpenAICompatCompletionResponse(choices=[choice])
+            return process_chat_completion_response(
+                compat_response, self.formatter
+            )
+
+        except Exception as e:
+            await self._handle_groq_error(e)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        try:
+            params = await self._get_params(request)
+            stream = await self.client.chat.completions.create(**params)
+
+            async def _to_async_generator():
+                async for chunk in stream:
+                    if chunk.choices[0].delta.tool_calls:
+                        tool_call = chunk.choices[0].delta.tool_calls[0]
+                        choice = OpenAICompatCompletionChoice(
+                            finish_reason=chunk.choices[0].finish_reason,
+                            text="",
+                            tool_calls=[
+                                {
+                                    "tool_name": tool_call.function.name
+                                    if tool_call.function
+                                    else None,
+                                    "arguments": tool_call.function.arguments
+                                    if tool_call.function
+                                    else "",
+                                }
+                            ]
+                            if tool_call.function
+                            else None,
+                        )
+                    else:
+                        choice = OpenAICompatCompletionChoice(
+                            finish_reason=chunk.choices[0].finish_reason,
+                            text=chunk.choices[0].delta.content or "",
+                            tool_calls=[],
+                        )
+                    yield OpenAICompatCompletionResponse(choices=[choice])
+
+            async for chunk in process_chat_completion_stream_response(
+                _to_async_generator(), self.formatter
+            ):
+                yield chunk
+
+        except Exception as e:
+            await self._handle_groq_error(e)
+
+    async def completion(
+        self,
+        model_id: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[
+        CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]
+    ]:
+        raise NotImplementedError("Groq does not support text completion")
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        """Handle chat completion requests.
+
+        Args:
+            model_id: The model identifier
+            messages: List of chat messages
+            sampling_params: Parameters for text generation
+            tools: Tool definitions for function calling
+            tool_choice: Tool choice option
+            tool_prompt_format: Tool prompt format
+            response_format: Response format (not supported)
+            stream: Whether to stream the response
+            logprobs: Log probability config (not supported)
+
+        Returns:
+            AsyncGenerator: The completion response
+        """
+        model = await self.model_store.get_model(model_id)
+        request = ChatCompletionRequest(
+            model=model.provider_resource_id,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        """Embeddings are not supported.
+
+        Raises:
+            NotImplementedError: Always raised as this feature is not supported
+        """
+        raise NotImplementedError("Embeddings not supported for Groq")
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 2007818e5..bf30079b6 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -22,6 +22,7 @@ from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.groq import GroqImplConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -156,6 +157,22 @@ def inference_nvidia() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_groq() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="groq",
+                provider_type="remote::groq",
+                config=GroqImplConfig().model_dump(),
+            )
+        ],
+        provider_data=dict(
+            groq_api_key=get_env_or_fail("GROQ_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
@@ -190,6 +207,7 @@ INFERENCE_FIXTURES = [
     "remote",
     "bedrock",
     "nvidia",
+    "groq",
 ]
 
diff --git a/llama_stack/templates/groq/__init__.py b/llama_stack/templates/groq/__init__.py
new file mode 100644
index 000000000..02a39601d
--- /dev/null
+++ b/llama_stack/templates/groq/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .groq import get_distribution_template  # noqa: F401
diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml
new file mode 100644
index 000000000..afa67b724
--- /dev/null
+++ b/llama_stack/templates/groq/build.yaml
@@ -0,0 +1,19 @@
+version: '2'
+name: groq
+distribution_spec:
+  description: Use Groq for running LLM inference
+  docker_image: null
+  providers:
+    inference:
+    - remote::groq
+    memory:
+    - inline::faiss
+    - remote::chromadb
+    - remote::pgvector
+    safety:
+    - inline::llama-guard
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+image_type: conda
diff --git a/llama_stack/templates/groq/doc_template.md b/llama_stack/templates/groq/doc_template.md
new file mode 100644
index 000000000..a72799bc1
--- /dev/null
+++ b/llama_stack/templates/groq/doc_template.md
@@ -0,0 +1,66 @@
+---
+orphan: true
+---
+# Groq Distribution
+
+```{toctree}
+:maxdepth: 2
+:hidden:
+
+self
+```
+
+The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
+
+{{ providers_table }}
+
+{% if run_config_env_vars %}
+### Environment Variables
+
+The following environment variables can be configured:
+
+{% for var, (default_value, description) in run_config_env_vars.items() %}
+- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
+{% endfor %}
+{% endif %}
+
+{% if default_models %}
+### Models
+
+The following models are available by default:
+
+{% for model in default_models %}
+- `{{ model.model_id }}`
+{% endfor %}
+{% endif %}
+
+### Prerequisite: API Keys
+
+Make sure you have access to a Groq API Key. You can get one by signing up at [console.groq.com](https://console.groq.com).
+
+## Running Llama Stack with Groq
+
+You can do this via Conda (build code) or Docker, which has a pre-built image.
+
+### Via Docker
+
+This method allows you to get started quickly without having to build the distribution code.
+
+```bash
+LLAMA_STACK_PORT=5001
+docker run \
+  -it \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  llamastack/distribution-{{ name }} \
+  --port $LLAMA_STACK_PORT \
+  --env GROQ_API_KEY=$GROQ_API_KEY
+```
+
+### Via Conda
+
+```bash
+llama stack build --template {{ name }} --image-type conda
+llama stack run ./run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env GROQ_API_KEY=$GROQ_API_KEY
+```
diff --git a/llama_stack/templates/groq/groq.py b/llama_stack/templates/groq/groq.py
new file mode 100644
index 000000000..c2b6a0f53
--- /dev/null
+++ b/llama_stack/templates/groq/groq.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_models.sku_list import all_registered_models
+
+from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.remote.inference.groq import GroqImplConfig
+from llama_stack.providers.remote.inference.groq.groq import MODEL_ALIASES
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)
+
+
+def get_distribution_template() -> DistributionTemplate:
+    providers = {
+        "inference": ["remote::groq"],
+        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+        "safety": ["inline::llama-guard"],
+        "agents": ["inline::meta-reference"],
+        "telemetry": ["inline::meta-reference"],
+    }
+
+    inference_provider = Provider(
+        provider_id="groq",
+        provider_type="remote::groq",
+        config=GroqImplConfig.sample_run_config(),
+    )
+
+    core_model_to_hf_repo = {
+        m.descriptor(): m.huggingface_repo for m in all_registered_models()
+    }
+
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model],
+            provider_model_id=m.provider_model_id,
+        )
+        for m in MODEL_ALIASES
+    ]
+
+    return DistributionTemplate(
+        name="groq",
+        distro_type="self_hosted",
+        description="Use Groq for running LLM inference",
+        docker_image=None,
+        template_path=Path(__file__).parent / "doc_template.md",
+        providers=providers,
+        default_models=default_models,
+        run_configs={
+            "run.yaml": RunConfigSettings(
+                provider_overrides={
+                    "inference": [inference_provider],
+                },
+                default_models=default_models,
+                default_shields=[
+                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+                ],
+            ),
+        },
+        run_config_env_vars={
+            "LLAMASTACK_PORT": (
+                "5001",
+                "Port for the Llama Stack distribution server",
+            ),
+            "GROQ_API_KEY": (
+                "",
+                "Groq API Key for authentication",
+            ),
+        },
+    )
diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml
new file mode 100644
index 000000000..43ff7ad77
--- /dev/null
+++ b/llama_stack/templates/groq/run.yaml
@@ -0,0 +1,95 @@
+version: '2'
+image_name: groq
+docker_image: null
+conda_env: groq
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+providers:
+  inference:
+  - provider_id: groq
+    provider_type: remote::groq
+    config:
+      url: https://api.groq.com/openai/v1
+      api_key: ${env.GROQ_API_KEY}
+  memory:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-1B-Instruct
+  provider_id: null
+  provider_model_id: llama-3.2-1b-preview
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: null
+  provider_model_id: llama-3.2-3b-preview
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: null
+  provider_model_id: llama-3.2-11b-vision-preview
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: null
+  provider_model_id: llama-3.2-90b-vision-preview
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: null
+  provider_model_id: llama-3.1-8b-instant
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
+  provider_id: null
+  provider_model_id: llama-3.1-70b-versatile
+- metadata: {}
+  model_id: meta-llama/Llama-3-8B-Instruct
+  provider_id: null
+  provider_model_id: llama3-8b-8192
+- metadata: {}
+  model_id: meta-llama/Llama-3-70B-Instruct
+  provider_id: null
+  provider_model_id: llama3-70b-8192
+- metadata: {}
+  model_id: meta-llama/Llama-3-8B-Instruct
+  provider_id: null
+  provider_model_id: llama3-groq-8b-8192-tool-use-preview
+- metadata: {}
+  model_id: meta-llama/Llama-3-70B-Instruct
+  provider_id: null
+  provider_model_id: llama3-groq-70b-8192-tool-use-preview
+- metadata: {}
+  model_id: meta-llama/Llama-Guard-3-8B
+  provider_id: null
+  provider_model_id: llama-guard-3-8b
+shields: []
+memory_banks: []
+datasets: []
+scoring_fns: []
+eval_tasks: []
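
For context, a minimal client-side sketch of exercising the new provider once a `groq` distribution is running. This is illustrative only and not part of the diff above: it assumes the `llama-stack-client` Python package, its `provider_data` keyword for populating the `X-LlamaStack-ProviderData` header, and a server listening on port 5001.

```python
# Hypothetical usage sketch (assumptions noted above; adjust names to your setup).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:5001",
    # Per-request credentials travel in the X-LlamaStack-ProviderData header,
    # which GroqProviderDataValidator parses into `groq_api_key` server-side.
    provider_data={"groq_api_key": "gsk_..."},
)

# Model IDs follow the aliases registered in run.yaml / MODEL_ALIASES.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```

If the API key is instead set via `GROQ_API_KEY` in the run config, the `provider_data` argument can be omitted, since `_get_api_key()` falls back to the configured `api_key`.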