mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat(providers): Groq now uses LiteLLM openai-compat (#1303)
Groq has never supported raw completions anyhow, which makes it easy to switch it to LiteLLM. Our entire test suite passes. I also updated all the openai-compat providers so that they work with API keys passed via request headers (`provider_data`).

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested the openai, anthropic, and gemini providers. No regressions.
This commit is contained in:
parent
564f0e5f93
commit
928a39d17b
23 changed files with 165 additions and 1004 deletions
|
@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -49,10 +50,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
class LiteLLMOpenAIMixin(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
NeedsRequestProviderData,
|
||||
):
|
||||
def __init__(self, model_entries) -> None:
|
||||
self.model_entries = model_entries
|
||||
def __init__(
    self,
    model_entries,
    api_key_from_config: str,
    provider_data_api_key_field: str,
):
    """Set up the model registry and remember where API keys come from.

    Args:
        model_entries: Forwarded unchanged to ``ModelRegistryHelper.__init__``.
        api_key_from_config: API key stored on the instance as
            ``self.api_key_from_config``.
        provider_data_api_key_field: Attribute name stored on the instance
            as ``self.provider_data_api_key_field``.
    """
    # The base initializer is invoked explicitly (not through super()).
    ModelRegistryHelper.__init__(self, model_entries)

    self.api_key_from_config = api_key_from_config
    self.provider_data_api_key_field = provider_data_api_key_field
|
||||
|
||||
async def initialize(self):
    """Lifecycle hook; this implementation performs no work."""
    return None
|
||||
|
||||
async def shutdown(self):
    """Lifecycle hook; this implementation performs no work."""
    return None
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
@ -144,8 +153,16 @@ class LiteLLMOpenAIMixin(
|
|||
if request.tool_config.tool_choice:
|
||||
input_dict["tool_choice"] = request.tool_config.tool_choice.value
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
if provider_data and getattr(provider_data, key_field, None):
|
||||
api_key = getattr(provider_data, key_field)
|
||||
else:
|
||||
api_key = self.api_key_from_config
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"api_key": api_key,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue