feat(providers): Groq now uses LiteLLM openai-compat (#1303)

Groq has never supported raw completions anyhow. So this makes it easier
to switch it to LiteLLM. All our test suite passes.

I also updated all the openai-compat providers so they work with api
keys passed from headers. `provider_data`

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
   pytest -s -v tests/client-sdk/inference/test_text_inference.py \
   --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested (openai, anthropic, gemini) providers. No regressions.
This commit is contained in:
Ashwin Bharambe 2025-02-27 13:16:50 -08:00 committed by GitHub
parent 564f0e5f93
commit 928a39d17b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 165 additions and 1004 deletions

View file

@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -49,10 +50,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
NeedsRequestProviderData,
):
def __init__(self, model_entries) -> None:
self.model_entries = model_entries
def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
ModelRegistryHelper.__init__(self, model_entries)
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
@ -144,8 +153,16 @@ class LiteLLMOpenAIMixin(
if request.tool_config.tool_choice:
input_dict["tool_choice"] = request.tool_config.tool_choice.value
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),