Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
feat(providers): Groq now uses LiteLLM openai-compat (#1303)
Groq has never supported raw completions anyhow, so this makes it easier to switch it to LiteLLM. All our test suite passes. I also updated all the openai-compat providers so they work with API keys passed from headers (`provider_data`).

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested the (openai, anthropic, gemini) providers. No regressions.
parent 564f0e5f93
commit 928a39d17b

23 changed files with 165 additions and 1004 deletions
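
The header-based key passing works the same way across the openai-compat providers: the client serializes the provider key into the `X-LlamaStack-Provider-Data` header, in the shape the old Groq adapter already documented (`{"groq_api_key": "<your api key>"}`). A minimal sketch; the base URL is an assumption, not taken from this diff:

```python
import json

import httpx

# Per-request provider key, serialized into the X-LlamaStack-Provider-Data
# header in the shape the adapters expect.
headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"groq_api_key": "<your api key>"}),
}

# Illustrative client against a locally running stack; any request sent with
# these headers carries the key, so the run config needs no static api_key.
client = httpx.Client(base_url="http://localhost:8321", headers=headers)
```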
```diff
@@ -146,6 +146,7 @@
     "fastapi",
     "fire",
     "fireworks-ai",
+    "groq",
     "httpx",
     "litellm",
     "matplotlib",
```

```diff
@@ -37,11 +37,11 @@ The following environment variables can be configured:
 
 The following models are available by default:
 
-- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)`
-- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)`
-- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)`
-- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)`
-- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)`
+- `meta-llama/Llama-3.1-8B-Instruct (groq/llama3-8b-8192)`
+- `meta-llama/Llama-3.1-8B-Instruct (groq/llama-3.1-8b-instant)`
+- `meta-llama/Llama-3-70B-Instruct (groq/llama3-70b-8192)`
+- `meta-llama/Llama-3.3-70B-Instruct (groq/llama-3.3-70b-versatile)`
+- `meta-llama/Llama-3.2-3B-Instruct (groq/llama-3.2-3b-preview)`
 
 ### Prerequisite: API Keys
```

```diff
@@ -157,16 +157,6 @@ def available_providers() -> List[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         ),
     ),
-    remote_provider_spec(
-        api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="groq",
-            pip_packages=["groq"],
-            module="llama_stack.providers.remote.inference.groq",
-            config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
-            provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
-        ),
-    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
```

```diff
@@ -214,6 +204,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.openai",
             config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
         ),
     ),
     remote_provider_spec(
```

```diff
@@ -223,6 +214,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.anthropic",
             config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
         ),
     ),
     remote_provider_spec(
```

```diff
@@ -232,6 +224,17 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.gemini",
             config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="groq",
+            pip_packages=["groq"],
+            module="llama_stack.providers.remote.inference.groq",
+            config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
+        ),
+    ),
     remote_provider_spec(
```

```diff
@@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
 
 class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: AnthropicConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            MODEL_ENTRIES,
+            api_key_from_config=config.api_key,
+            provider_data_api_key_field="anthropic_api_key",
+        )
         self.config = config
 
     async def initialize(self) -> None:
-        pass
+        await super().initialize()
 
     async def shutdown(self) -> None:
-        pass
+        await super().shutdown()
```

```diff
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
 from llama_stack.schema_utils import json_schema_type
 
 
+class AnthropicProviderDataValidator(BaseModel):
+    anthropic_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for Anthropic models",
+    )
+
+
 @json_schema_type
 class AnthropicConfig(BaseModel):
     api_key: Optional[str] = Field(
```

```diff
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
 from llama_stack.schema_utils import json_schema_type
 
 
+class GeminiProviderDataValidator(BaseModel):
+    gemini_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for Gemini models",
+    )
+
+
 @json_schema_type
 class GeminiConfig(BaseModel):
     api_key: Optional[str] = Field(
```

```diff
@@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
 
 class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: GeminiConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            MODEL_ENTRIES,
+            api_key_from_config=config.api_key,
+            provider_data_api_key_field="gemini_api_key",
+        )
         self.config = config
 
     async def initialize(self) -> None:
-        pass
+        await super().initialize()
 
     async def shutdown(self) -> None:
-        pass
+        await super().shutdown()
```

```diff
@@ -4,23 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel
-
 from llama_stack.apis.inference import Inference
 
 from .config import GroqConfig
 
 
-class GroqProviderDataValidator(BaseModel):
-    groq_api_key: str
-
-
 async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
     # import dynamically so the import is used only when it is needed
     from .groq import GroqInferenceAdapter
 
     if not isinstance(config, GroqConfig):
         raise RuntimeError(f"Unexpected config type: {type(config)}")
 
     adapter = GroqInferenceAdapter(config)
     return adapter
```

```diff
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
 from llama_stack.schema_utils import json_schema_type
 
 
+class GroqProviderDataValidator(BaseModel):
+    groq_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for Groq models",
+    )
+
+
 @json_schema_type
 class GroqConfig(BaseModel):
     api_key: Optional[str] = Field(
```

```diff
@@ -25,8 +32,8 @@ class GroqConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.groq.com",
-            "api_key": "${env.GROQ_API_KEY}",
+            "api_key": api_key,
         }
```

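The new `api_key` parameter lets templates substitute their own environment placeholder instead of the hard-wired one. A small sketch of both call styles (the second is exactly what the dev template further down passes; reading `${env.VAR:}` as "optional, with empty default" follows llama-stack's env-substitution convention):

```python
from llama_stack.providers.remote.inference.groq.config import GroqConfig

# Default: the required-env placeholder baked into the signature.
GroqConfig.sample_run_config()
# -> {"url": "https://api.groq.com", "api_key": "${env.GROQ_API_KEY}"}

# Dev-template style: optional placeholder, so the stack can boot without the
# key and callers supply it per request via provider_data instead.
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}")
# -> {"url": "https://api.groq.com", "api_key": "${env.GROQ_API_KEY:}"}
```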
```diff
@@ -4,130 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import warnings
-from typing import AsyncIterator, List, Optional, Union
-
-import groq
-from groq import Groq
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
-    EmbeddingsResponse,
-    EmbeddingTaskType,
-    Inference,
-    InterleavedContent,
-    InterleavedContentItem,
-    LogProbConfig,
-    Message,
-    ResponseFormat,
-    TextTruncation,
-    ToolChoice,
-    ToolConfig,
-)
-from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
-from .groq_utils import (
-    convert_chat_completion_request,
-    convert_chat_completion_response,
-    convert_chat_completion_response_stream,
-)
-from .models import _MODEL_ENTRIES
+from .models import MODEL_ENTRIES
 
 
-class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
+class GroqInferenceAdapter(LiteLLMOpenAIMixin):
     _config: GroqConfig
 
     def __init__(self, config: GroqConfig):
-        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
-        self._config = config
-
-    def completion(
-        self,
-        model_id: str,
-        content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        # Groq doesn't support non-chat completion as of time of writing
-        raise NotImplementedError()
-
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
-        model_id = self.get_provider_model_id(model_id)
-        if model_id == "llama-3.2-3b-preview":
-            warnings.warn(
-                "Groq only contains a preview version for llama-3.2-3b-instruct. "
-                "Preview models aren't recommended for production use. "
-                "They can be discontinued on short notice."
-                "More details: https://console.groq.com/docs/models"
-            )
-
-        request = convert_chat_completion_request(
-            request=ChatCompletionRequest(
-                model=model_id,
-                messages=messages,
-                sampling_params=sampling_params,
-                response_format=response_format,
-                tools=tools,
-                stream=stream,
-                logprobs=logprobs,
-                tool_config=tool_config,
-            )
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            model_entries=MODEL_ENTRIES,
+            api_key_from_config=config.api_key,
+            provider_data_api_key_field="groq_api_key",
         )
+        self.config = config
 
-        try:
-            response = self._get_client().chat.completions.create(**request)
-        except groq.BadRequestError as e:
-            if e.body.get("error", {}).get("code") == "tool_use_failed":
-                # For smaller models, Groq may fail to call a tool even when the request is well formed
-                raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
-            else:
-                raise e
+    async def initialize(self):
+        await super().initialize()
 
-        if stream:
-            return convert_chat_completion_response_stream(response)
-        else:
-            return convert_chat_completion_response(response)
-
-    async def embeddings(
-        self,
-        model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
-    ) -> EmbeddingsResponse:
-        raise NotImplementedError()
-
-    def _get_client(self) -> Groq:
-        if self._config.api_key is not None:
-            return Groq(api_key=self._config.api_key)
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.groq_api_key:
-                raise ValueError(
-                    'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
-                )
-            return Groq(api_key=provider_data.groq_api_key)
+    async def shutdown(self):
+        await super().shutdown()
```

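With this rewrite the adapter keeps no request/response conversion of its own: `chat_completion` and friends are inherited from `LiteLLMOpenAIMixin`, which builds its LiteLLM parameters in `_get_params` (see the mixin hunk further down). A sketch of what construction now amounts to; passing `api_key=None` is an assumption here, chosen to illustrate the fallback:

```python
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter

# No static key: the mixin will look for "groq_api_key" in the request's
# provider_data before falling back to config.api_key (None here).
adapter = GroqInferenceAdapter(GroqConfig(api_key=None))
```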
```diff
@@ -1,245 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import json
-import warnings
-from typing import AsyncGenerator, Literal
-
-from groq import Stream
-from groq.types.chat.chat_completion import ChatCompletion
-from groq.types.chat.chat_completion_assistant_message_param import (
-    ChatCompletionAssistantMessageParam,
-)
-from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
-from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
-from groq.types.chat.chat_completion_system_message_param import (
-    ChatCompletionSystemMessageParam,
-)
-from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
-from groq.types.chat.chat_completion_user_message_param import (
-    ChatCompletionUserMessageParam,
-)
-from groq.types.chat.completion_create_params import CompletionCreateParams
-from groq.types.shared.function_definition import FunctionDefinition
-
-from llama_stack.apis.common.content_types import (
-    TextDelta,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_stack.models.llama.datatypes import ToolParamDefinition
-from llama_stack.providers.utils.inference.openai_compat import (
-    UnparseableToolCall,
-    convert_tool_call,
-    get_sampling_strategy_options,
-)
-
-
-def convert_chat_completion_request(
-    request: ChatCompletionRequest,
-) -> CompletionCreateParams:
-    """
-    Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
-    Warns client if request contains unsupported features.
-    """
-
-    if request.logprobs:
-        # Groq doesn't support logprobs at the time of writing
-        warnings.warn("logprobs are not supported yet")
-
-    if request.response_format:
-        # Groq's JSON mode is beta at the time of writing
-        warnings.warn("response_format is not supported yet")
-
-    if request.sampling_params.repetition_penalty != 1.0:
-        # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
-        # seem to have different semantics
-        # frequency_penalty defaults to 0 is a float between -2.0 and 2.0
-        # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
-        # so we exclude it for now
-        warnings.warn("repetition_penalty is not supported")
-
-    if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
-        warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
-
-    sampling_options = get_sampling_strategy_options(request.sampling_params)
-    return CompletionCreateParams(
-        model=request.model,
-        messages=[_convert_message(message) for message in request.messages],
-        logprobs=None,
-        frequency_penalty=None,
-        stream=request.stream,
-        max_tokens=request.sampling_params.max_tokens or None,
-        temperature=sampling_options.get("temperature", 1.0),
-        top_p=sampling_options.get("top_p", 1.0),
-        tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
-        tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
-    )
-
-
-def _convert_message(message: Message) -> ChatCompletionMessageParam:
-    if message.role == "system":
-        return ChatCompletionSystemMessageParam(role="system", content=message.content)
-    elif message.role == "user":
-        return ChatCompletionUserMessageParam(role="user", content=message.content)
-    elif message.role == "assistant":
-        return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)
-    else:
-        raise ValueError(f"Invalid message role: {message.role}")
-
-
-def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
-    # Groq requires a description for function tools
-    if tool_definition.description is None:
-        raise AssertionError("tool_definition.description is required")
-
-    tool_parameters = tool_definition.parameters or {}
-    return ChatCompletionToolParam(
-        type="function",
-        function=FunctionDefinition(
-            name=tool_definition.tool_name,
-            description=tool_definition.description,
-            parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()},
-        ),
-    )
-
-
-def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
-    param = {
-        "type": tool_parameter.param_type,
-    }
-    if tool_parameter.description is not None:
-        param["description"] = tool_parameter.description
-    if tool_parameter.required is not None:
-        param["required"] = tool_parameter.required
-    if tool_parameter.default is not None:
-        param["default"] = tool_parameter.default
-    return param
-
-
-def convert_chat_completion_response(
-    response: ChatCompletion,
-) -> ChatCompletionResponse:
-    # groq only supports n=1 at time of writing, so there is only one choice
-    choice = response.choices[0]
-    if choice.finish_reason == "tool_calls":
-        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
-        if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
-            # If we couldn't parse a tool call, jsonify the tool calls and return them
-            return ChatCompletionResponse(
-                completion_message=CompletionMessage(
-                    stop_reason=StopReason.end_of_message,
-                    content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
-                ),
-                logprobs=None,
-            )
-        else:
-            # Otherwise, return tool calls as normal
-            return ChatCompletionResponse(
-                completion_message=CompletionMessage(
-                    tool_calls=tool_calls,
-                    stop_reason=StopReason.end_of_message,
-                    # Content is not optional
-                    content="",
-                ),
-                logprobs=None,
-            )
-    else:
-        return ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                content=choice.message.content,
-                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
-            ),
-        )
-
-
-def _map_finish_reason_to_stop_reason(
-    finish_reason: Literal["stop", "length", "tool_calls"],
-) -> StopReason:
-    """
-    Convert a Groq chat completion finish_reason to a StopReason.
-
-    finish_reason: Literal["stop", "length", "tool_calls"]
-    - stop -> model hit a natural stop point or a provided stop sequence
-    - length -> maximum number of tokens specified in the request was reached
-    - tool_calls -> model called a tool
-    """
-    if finish_reason == "stop":
-        return StopReason.end_of_turn
-    elif finish_reason == "length":
-        return StopReason.out_of_tokens
-    elif finish_reason == "tool_calls":
-        return StopReason.end_of_message
-    else:
-        raise ValueError(f"Invalid finish reason: {finish_reason}")
-
-
-async def convert_chat_completion_response_stream(
-    stream: Stream[ChatCompletionChunk],
-) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-    event_type = ChatCompletionResponseEventType.start
-    for chunk in stream:
-        choice = chunk.choices[0]
-
-        if choice.finish_reason:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=TextDelta(text=choice.delta.content or ""),
-                    logprobs=None,
-                    stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
-                )
-            )
-        elif choice.delta.tool_calls:
-            # We assume there is only one tool call per chunk, but emit a warning in case we're wrong
-            if len(choice.delta.tool_calls) > 1:
-                warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
-
-            # We assume Groq produces fully formed tool calls for each chunk
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            if isinstance(tool_call, ToolCall):
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=event_type,
-                        delta=ToolCallDelta(
-                            tool_call=tool_call,
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
-                    )
-                )
-            else:
-                # Otherwise it's an UnparseableToolCall - return the raw tool call
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=event_type,
-                        delta=ToolCallDelta(
-                            tool_call=tool_call.model_dump_json(),
-                            parse_status=ToolCallParseStatus.failed,
-                        ),
-                    )
-                )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=event_type,
-                    delta=TextDelta(text=choice.delta.content or ""),
-                    logprobs=None,
-                )
-            )
-        event_type = ChatCompletionResponseEventType.progress
```

```diff
@@ -7,21 +7,21 @@
 from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import build_model_entry
 
-_MODEL_ENTRIES = [
+MODEL_ENTRIES = [
     build_model_entry(
-        "llama3-8b-8192",
+        "groq/llama3-8b-8192",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_entry(
-        "llama-3.1-8b-instant",
+        "groq/llama-3.1-8b-instant",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_entry(
-        "llama3-70b-8192",
+        "groq/llama3-70b-8192",
         CoreModelId.llama3_70b_instruct.value,
     ),
     build_model_entry(
-        "llama-3.3-70b-versatile",
+        "groq/llama-3.3-70b-versatile",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
     # Groq only contains a preview version for llama-3.2-3b
```

```diff
@@ -29,7 +29,7 @@ _MODEL_ENTRIES = [
     # to pass the test fixture
     # TODO(aidand): Replace this with a stable model once Groq supports it
     build_model_entry(
-        "llama-3.2-3b-preview",
+        "groq/llama-3.2-3b-preview",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
 ]
```

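The `groq/` prefix on each `provider_model_id` is what makes the LiteLLM switch work: LiteLLM dispatches on the provider prefix of the model string, so these registry entries can be passed through unchanged. A hedged sketch of the underlying call (the exact keyword set the mixin forwards is an assumption; `litellm.completion` itself is LiteLLM's real entry point):

```python
import litellm

# LiteLLM routes on the "groq/" prefix, so a provider_model_id from
# MODEL_ENTRIES works directly as the model string.
response = litellm.completion(
    model="groq/llama-3.3-70b-versatile",
    messages=[{"role": "user", "content": "Hello!"}],
    api_key="<your groq api key>",  # resolved from provider_data or config
)
print(response.choices[0].message.content)
```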
```diff
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
 from llama_stack.schema_utils import json_schema_type
 
 
+class OpenAIProviderDataValidator(BaseModel):
+    openai_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for OpenAI models",
+    )
+
+
 @json_schema_type
 class OpenAIConfig(BaseModel):
     api_key: Optional[str] = Field(
```

```diff
@@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
 
 class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: OpenAIConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            MODEL_ENTRIES,
+            api_key_from_config=config.api_key,
+            provider_data_api_key_field="openai_api_key",
+        )
         self.config = config
 
     async def initialize(self) -> None:
-        pass
+        await super().initialize()
 
     async def shutdown(self) -> None:
-        pass
+        await super().shutdown()
```

```diff
@@ -1,575 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import json
-
-import pytest
-from groq.types.chat.chat_completion import ChatCompletion, Choice
-from groq.types.chat.chat_completion_chunk import (
-    ChatCompletionChunk,
-    ChoiceDelta,
-    ChoiceDeltaToolCall,
-    ChoiceDeltaToolCallFunction,
-)
-from groq.types.chat.chat_completion_chunk import (
-    Choice as StreamChoice,
-)
-from groq.types.chat.chat_completion_message import ChatCompletionMessage
-from groq.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-    Function,
-)
-from groq.types.shared.function_definition import FunctionDefinition
-
-from llama_stack.apis.common.content_types import ToolCallParseStatus
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    StopReason,
-    SystemMessage,
-    ToolCall,
-    ToolChoice,
-    ToolDefinition,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy
-from llama_stack.providers.remote.inference.groq.groq_utils import (
-    convert_chat_completion_request,
-    convert_chat_completion_response,
-    convert_chat_completion_response_stream,
-)
-
-
-class TestConvertChatCompletionRequest:
-    def test_sets_model(self):
-        request = self._dummy_chat_completion_request()
-        request.model = "Llama-3.2-3B"
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["model"] == "Llama-3.2-3B"
-
-    def test_converts_user_message(self):
-        request = self._dummy_chat_completion_request()
-        request.messages = [UserMessage(content="Hello World")]
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["messages"] == [
-            {"role": "user", "content": "Hello World"},
-        ]
-
-    def test_converts_system_message(self):
-        request = self._dummy_chat_completion_request()
-        request.messages = [SystemMessage(content="You are a helpful assistant.")]
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["messages"] == [
-            {"role": "system", "content": "You are a helpful assistant."},
-        ]
-
-    def test_converts_completion_message(self):
-        request = self._dummy_chat_completion_request()
-        request.messages = [
-            UserMessage(content="Hello World"),
-            CompletionMessage(
-                content="Hello World! How can I help you today?",
-                stop_reason=StopReason.end_of_message,
-            ),
-        ]
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["messages"] == [
-            {"role": "user", "content": "Hello World"},
-            {"role": "assistant", "content": "Hello World! How can I help you today?"},
-        ]
-
-    def test_does_not_include_logprobs(self):
-        request = self._dummy_chat_completion_request()
-        request.logprobs = True
-
-        with pytest.warns(Warning) as warnings:
-            converted = convert_chat_completion_request(request)
-
-        assert "logprobs are not supported yet" in warnings[0].message.args[0]
-        assert converted.get("logprobs") is None
-
-    def test_does_not_include_response_format(self):
-        request = self._dummy_chat_completion_request()
-        request.response_format = {
-            "type": "json_object",
-            "json_schema": {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "age": {"type": "number"},
-                },
-            },
-        }
-
-        with pytest.warns(Warning) as warnings:
-            converted = convert_chat_completion_request(request)
-
-        assert "response_format is not supported yet" in warnings[0].message.args[0]
-        assert converted.get("response_format") is None
-
-    def test_does_not_include_repetition_penalty(self):
-        request = self._dummy_chat_completion_request()
-        request.sampling_params.repetition_penalty = 1.5
-
-        with pytest.warns(Warning) as warnings:
-            converted = convert_chat_completion_request(request)
-
-        assert "repetition_penalty is not supported" in warnings[0].message.args[0]
-        assert converted.get("repetition_penalty") is None
-        assert converted.get("frequency_penalty") is None
-
-    def test_includes_stream(self):
-        request = self._dummy_chat_completion_request()
-        request.stream = True
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["stream"] is True
-
-    def test_if_max_tokens_is_0_then_it_is_not_included(self):
-        request = self._dummy_chat_completion_request()
-        # 0 is the default value for max_tokens
-        # So we assume that if it's 0, the user didn't set it
-        request.sampling_params.max_tokens = 0
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted.get("max_tokens") is None
-
-    def test_includes_max_tokens_if_set(self):
-        request = self._dummy_chat_completion_request()
-        request.sampling_params.max_tokens = 100
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["max_tokens"] == 100
-
-    def _dummy_chat_completion_request(self):
-        return ChatCompletionRequest(
-            model="Llama-3.2-3B",
-            messages=[UserMessage(content="Hello World")],
-        )
-
-    def test_includes_stratgy(self):
-        request = self._dummy_chat_completion_request()
-        request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95)
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["temperature"] == 0.5
-        assert converted["top_p"] == 0.95
-
-    def test_includes_greedy_strategy(self):
-        request = self._dummy_chat_completion_request()
-        request.sampling_params.strategy = GreedySamplingStrategy()
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["temperature"] == 0.0
-
-    def test_includes_tool_choice(self):
-        request = self._dummy_chat_completion_request()
-        request.tool_config.tool_choice = ToolChoice.required
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["tool_choice"] == "required"
-
-    def test_includes_tools(self):
-        request = self._dummy_chat_completion_request()
-        request.tools = [
-            ToolDefinition(
-                tool_name="get_flight_info",
-                description="Get fight information between two destinations.",
-                parameters={
-                    "origin": ToolParamDefinition(
-                        param_type="string",
-                        description="The origin airport code. E.g., AU",
-                        required=True,
-                    ),
-                    "destination": ToolParamDefinition(
-                        param_type="string",
-                        description="The destination airport code. E.g., 'LAX'",
-                        required=True,
-                    ),
-                    "passengers": ToolParamDefinition(
-                        param_type="array",
-                        description="The passengers",
-                        required=False,
-                    ),
-                },
-            ),
-            ToolDefinition(
-                tool_name="log",
-                description="Calulate the logarithm of a number",
-                parameters={
-                    "number": ToolParamDefinition(
-                        param_type="float",
-                        description="The number to calculate the logarithm of",
-                        required=True,
-                    ),
-                    "base": ToolParamDefinition(
-                        param_type="integer",
-                        description="The base of the logarithm",
-                        required=False,
-                        default=10,
-                    ),
-                },
-            ),
-        ]
-
-        converted = convert_chat_completion_request(request)
-
-        assert converted["tools"] == [
-            {
-                "type": "function",
-                "function": FunctionDefinition(
-                    name="get_flight_info",
-                    description="Get fight information between two destinations.",
-                    parameters={
-                        "origin": {
-                            "type": "string",
-                            "description": "The origin airport code. E.g., AU",
-                            "required": True,
-                        },
-                        "destination": {
-                            "type": "string",
-                            "description": "The destination airport code. E.g., 'LAX'",
-                            "required": True,
-                        },
-                        "passengers": {
-                            "type": "array",
-                            "description": "The passengers",
-                            "required": False,
-                        },
-                    },
-                ),
-            },
-            {
-                "type": "function",
-                "function": FunctionDefinition(
-                    name="log",
-                    description="Calulate the logarithm of a number",
-                    parameters={
-                        "number": {
-                            "type": "float",
-                            "description": "The number to calculate the logarithm of",
-                            "required": True,
-                        },
-                        "base": {
-                            "type": "integer",
-                            "description": "The base of the logarithm",
-                            "required": False,
-                            "default": 10,
-                        },
-                    },
-                ),
-            },
-        ]
-
-
-class TestConvertNonStreamChatCompletionResponse:
-    def test_returns_response(self):
-        response = self._dummy_chat_completion_response()
-        response.choices[0].message.content = "Hello World"
-
-        converted = convert_chat_completion_response(response)
-
-        assert converted.completion_message.content == "Hello World"
-
-    def test_maps_stop_to_end_of_message(self):
-        response = self._dummy_chat_completion_response()
-        response.choices[0].finish_reason = "stop"
-
-        converted = convert_chat_completion_response(response)
-
-        assert converted.completion_message.stop_reason == StopReason.end_of_turn
-
-    def test_maps_length_to_end_of_message(self):
-        response = self._dummy_chat_completion_response()
-        response.choices[0].finish_reason = "length"
-
-        converted = convert_chat_completion_response(response)
-
-        assert converted.completion_message.stop_reason == StopReason.out_of_tokens
-
-    def test_maps_tool_call_to_end_of_message(self):
-        response = self._dummy_chat_completion_response_with_tool_call()
-
-        converted = convert_chat_completion_response(response)
-
-        assert converted.completion_message.stop_reason == StopReason.end_of_message
-
-    def test_converts_multiple_tool_calls(self):
-        response = self._dummy_chat_completion_response_with_tool_call()
-        response.choices[0].message.tool_calls = [
-            ChatCompletionMessageToolCall(
-                id="tool_call_id",
-                type="function",
-                function=Function(
-                    name="get_flight_info",
-                    arguments='{"origin": "AU", "destination": "LAX"}',
-                ),
-            ),
-            ChatCompletionMessageToolCall(
-                id="tool_call_id_2",
-                type="function",
-                function=Function(
-                    name="log",
-                    arguments='{"number": 10, "base": 2}',
-                ),
-            ),
-        ]
-
-        converted = convert_chat_completion_response(response)
-
-        assert converted.completion_message.tool_calls == [
-            ToolCall(
-                call_id="tool_call_id",
-                tool_name="get_flight_info",
-                arguments={"origin": "AU", "destination": "LAX"},
-            ),
-            ToolCall(
-                call_id="tool_call_id_2",
-                tool_name="log",
-                arguments={"number": 10, "base": 2},
-            ),
-        ]
-
-    def test_converts_unparseable_tool_calls(self):
-        response = self._dummy_chat_completion_response_with_tool_call()
-        response.choices[0].message.tool_calls = [
-            ChatCompletionMessageToolCall(
-                id="tool_call_id",
-                type="function",
-                function=Function(
-                    name="log",
-                    arguments="(number=10, base=2)",
-                ),
-            ),
-        ]
-
-        converted = convert_chat_completion_response(response)
-
-        assert (
-            converted.completion_message.content
-            == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
-        )
-
-    def _dummy_chat_completion_response(self):
-        return ChatCompletion(
-            id="chatcmpl-123",
-            model="Llama-3.2-3B",
-            choices=[
-                Choice(
-                    index=0,
-                    message=ChatCompletionMessage(role="assistant", content="Hello World"),
-                    finish_reason="stop",
-                )
-            ],
-            created=1729382400,
-            object="chat.completion",
-        )
-
-    def _dummy_chat_completion_response_with_tool_call(self):
-        return ChatCompletion(
-            id="chatcmpl-123",
-            model="Llama-3.2-3B",
-            choices=[
-                Choice(
-                    index=0,
-                    message=ChatCompletionMessage(
-                        role="assistant",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                id="tool_call_id",
-                                type="function",
-                                function=Function(
-                                    name="get_flight_info",
-                                    arguments='{"origin": "AU", "destination": "LAX"}',
-                                ),
-                            )
-                        ],
-                    ),
-                    finish_reason="tool_calls",
-                )
-            ],
-            created=1729382400,
-            object="chat.completion",
-        )
-
-
-class TestConvertStreamChatCompletionResponse:
-    @pytest.mark.asyncio
-    async def test_returns_stream(self):
-        def chat_completion_stream():
-            messages = ["Hello ", "World ", " !"]
-            for i, message in enumerate(messages):
-                chunk = self._dummy_chat_completion_chunk()
-                chunk.choices[0].delta.content = message
-                yield chunk
-
-            chunk = self._dummy_chat_completion_chunk()
-            chunk.choices[0].delta.content = None
-            chunk.choices[0].finish_reason = "stop"
-            yield chunk
-
-        stream = chat_completion_stream()
-        converted = convert_chat_completion_response_stream(stream)
-
-        iter = converted.__aiter__()
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.start
-        assert chunk.event.delta.text == "Hello "
-
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta.text == "World "
-
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta.text == " !"
-
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
-        assert chunk.event.delta.text == ""
-        assert chunk.event.stop_reason == StopReason.end_of_turn
-
-        with pytest.raises(StopAsyncIteration):
-            await iter.__anext__()
-
-    @pytest.mark.asyncio
-    async def test_returns_tool_calls_stream(self):
-        def tool_call_stream():
-            tool_calls = [
-                ToolCall(
-                    call_id="tool_call_id",
-                    tool_name="get_flight_info",
-                    arguments={"origin": "AU", "destination": "LAX"},
-                ),
-                ToolCall(
-                    call_id="tool_call_id_2",
-                    tool_name="log",
-                    arguments={"number": 10, "base": 2},
-                ),
-            ]
-            for i, tool_call in enumerate(tool_calls):
-                chunk = self._dummy_chat_completion_chunk_with_tool_call()
-                chunk.choices[0].delta.tool_calls = [
-                    ChoiceDeltaToolCall(
-                        index=0,
-                        type="function",
-                        id=tool_call.call_id,
-                        function=ChoiceDeltaToolCallFunction(
-                            name=tool_call.tool_name,
-                            arguments=json.dumps(tool_call.arguments),
-                        ),
-                    ),
-                ]
-                yield chunk
-
-            chunk = self._dummy_chat_completion_chunk_with_tool_call()
-            chunk.choices[0].delta.content = None
-            chunk.choices[0].finish_reason = "stop"
-            yield chunk
-
-        stream = tool_call_stream()
-        converted = convert_chat_completion_response_stream(stream)
-
-        iter = converted.__aiter__()
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.start
-        assert chunk.event.delta.tool_call == ToolCall(
-            call_id="tool_call_id",
-            tool_name="get_flight_info",
-            arguments={"origin": "AU", "destination": "LAX"},
-        )
-
-    @pytest.mark.asyncio
-    async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
-        def tool_call_stream():
-            chunk = self._dummy_chat_completion_chunk_with_tool_call()
-            chunk.choices[0].delta.tool_calls = [
-                ChoiceDeltaToolCall(
-                    index=0,
-                    type="function",
-                    id="tool_call_id",
-                    function=ChoiceDeltaToolCallFunction(
-                        name="get_flight_info",
-                        arguments="(origin=AU, destination=LAX)",
-                    ),
-                ),
-            ]
-            yield chunk
-
-            chunk = self._dummy_chat_completion_chunk_with_tool_call()
-            chunk.choices[0].delta.content = None
-            chunk.choices[0].finish_reason = "stop"
-            yield chunk
-
-        stream = tool_call_stream()
-        converted = convert_chat_completion_response_stream(stream)
-
-        iter = converted.__aiter__()
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.start
-        assert (
-            chunk.event.delta.content
-            == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
-        )
-        assert chunk.event.delta.parse_status == ToolCallParseStatus.failed
-
-    def _dummy_chat_completion_chunk(self):
-        return ChatCompletionChunk(
-            id="chatcmpl-123",
-            model="Llama-3.2-3B",
-            choices=[
-                StreamChoice(
-                    index=0,
-                    delta=ChoiceDelta(role="assistant", content="Hello World"),
-                )
-            ],
-            created=1729382400,
-            object="chat.completion.chunk",
-            x_groq=None,
-        )
-
-    def _dummy_chat_completion_chunk_with_tool_call(self):
-        return ChatCompletionChunk(
-            id="chatcmpl-123",
-            model="Llama-3.2-3B",
-            choices=[
-                StreamChoice(
-                    index=0,
-                    delta=ChoiceDelta(
-                        role="assistant",
-                        content="Hello World",
-                        tool_calls=[
-                            ChoiceDeltaToolCall(
-                                index=0,
-                                type="function",
-                                function=ChoiceDeltaToolCallFunction(
-                                    name="get_flight_info",
-                                    arguments='{"origin": "AU", "destination": "LAX"}',
-                                ),
-                            )
-                        ],
-                    ),
-                )
-            ],
-            created=1729382400,
-            object="chat.completion.chunk",
-            x_groq=None,
-        )
```

```diff
@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models.models import Model
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
```

```diff
@@ -49,10 +50,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 class LiteLLMOpenAIMixin(
     ModelRegistryHelper,
     Inference,
+    NeedsRequestProviderData,
 ):
-    def __init__(self, model_entries) -> None:
-        self.model_entries = model_entries
+    def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
+        ModelRegistryHelper.__init__(self, model_entries)
+        self.api_key_from_config = api_key_from_config
+        self.provider_data_api_key_field = provider_data_api_key_field
 
     async def initialize(self):
         pass
 
     async def shutdown(self):
         pass
 
     async def register_model(self, model: Model) -> Model:
         model_id = self.get_provider_model_id(model.provider_resource_id)
```

```diff
@@ -144,8 +153,16 @@ class LiteLLMOpenAIMixin(
         if request.tool_config.tool_choice:
             input_dict["tool_choice"] = request.tool_config.tool_choice.value
 
+        provider_data = self.get_request_provider_data()
+        key_field = self.provider_data_api_key_field
+        if provider_data and getattr(provider_data, key_field, None):
+            api_key = getattr(provider_data, key_field)
+        else:
+            api_key = self.api_key_from_config
+
         return {
             "model": request.model,
+            "api_key": api_key,
             **input_dict,
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
```

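The effect is that a per-request key from `provider_data` takes precedence over the key from the run config. The same resolution order as a self-contained sketch (the `SimpleNamespace` stand-in for the validated provider data is an assumption for illustration):

```python
from types import SimpleNamespace


def resolve_api_key(provider_data, key_field, key_from_config):
    # A key passed via the X-LlamaStack-Provider-Data header wins...
    if provider_data and getattr(provider_data, key_field, None):
        return getattr(provider_data, key_field)
    # ...otherwise fall back to the key baked into the provider config.
    return key_from_config


header_data = SimpleNamespace(groq_api_key="gsk_from_header")
assert resolve_api_key(header_data, "groq_api_key", "gsk_from_config") == "gsk_from_header"
assert resolve_api_key(None, "groq_api_key", "gsk_from_config") == "gsk_from_config"
```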
```diff
@@ -7,6 +7,7 @@ distribution_spec:
     - remote::fireworks
     - remote::anthropic
     - remote::gemini
+    - remote::groq
     - inline::sentence-transformers
   vector_io:
     - inline::sqlite-vec
```

```diff
@@ -24,6 +24,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
 from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES
 from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
 from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES
+from llama_stack.providers.remote.inference.groq.config import GroqConfig
+from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
```

```diff
@@ -52,6 +54,11 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
             GEMINI_MODEL_ENTRIES,
             GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
         ),
+        (
+            "groq",
+            GROQ_MODEL_ENTRIES,
+            GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
+        ),
     ]
     inference_providers = []
     default_models = []
```

```diff
@@ -78,14 +85,9 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
 
 
 def get_distribution_template() -> DistributionTemplate:
+    inference_providers, default_models = get_inference_providers()
     providers = {
-        "inference": [
-            "remote::openai",
-            "remote::fireworks",
-            "remote::anthropic",
-            "remote::gemini",
-            "inline::sentence-transformers",
-        ],
+        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
```

```diff
@@ -136,7 +138,6 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
-    inference_providers, default_models = get_inference_providers()
 
     return DistributionTemplate(
         name=name,
```

```diff
@@ -29,6 +29,11 @@ providers:
     provider_type: remote::gemini
     config:
       api_key: ${env.GEMINI_API_KEY:}
+  - provider_id: groq
+    provider_type: remote::groq
+    config:
+      url: https://api.groq.com
+      api_key: ${env.GROQ_API_KEY:}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
```

```diff
@@ -241,6 +246,31 @@ models:
   provider_id: gemini
   provider_model_id: gemini/text-embedding-004
   model_type: embedding
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: groq
+  provider_model_id: groq/llama3-8b-8192
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: groq
+  provider_model_id: groq/llama-3.1-8b-instant
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3-70B-Instruct
+  provider_id: groq
+  provider_model_id: groq/llama3-70b-8192
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.3-70B-Instruct
+  provider_id: groq
+  provider_model_id: groq/llama-3.3-70b-versatile
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: groq
+  provider_model_id: groq/llama-3.2-3b-preview
+  model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
```

```diff
@@ -16,9 +16,8 @@ from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
-from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
-from llama_stack.providers.remote.inference.groq.models import _MODEL_ENTRIES
+from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
```

```diff
@@ -52,11 +51,6 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::sentence-transformers",
         config=SentenceTransformersInferenceConfig.sample_run_config(),
     )
-    vector_io_provider = Provider(
-        provider_id="faiss",
-        provider_type="inline::faiss",
-        config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
-    )
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
         provider_id="sentence-transformers",
```

```diff
@@ -69,11 +63,13 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id=name,
+            model_type=m.model_type,
+            metadata=m.metadata,
         )
-        for m in _MODEL_ENTRIES
+        for m in MODEL_ENTRIES
     ]
 
     default_tool_groups = [
```

```diff
@@ -93,27 +93,27 @@ models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
   provider_id: groq
-  provider_model_id: llama3-8b-8192
+  provider_model_id: groq/llama3-8b-8192
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
   provider_id: groq
-  provider_model_id: llama-3.1-8b-instant
+  provider_model_id: groq/llama-3.1-8b-instant
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3-70B-Instruct
   provider_id: groq
-  provider_model_id: llama3-70b-8192
+  provider_model_id: groq/llama3-70b-8192
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.3-70B-Instruct
   provider_id: groq
-  provider_model_id: llama-3.3-70b-versatile
+  provider_model_id: groq/llama-3.3-70b-versatile
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
   provider_id: groq
-  provider_model_id: llama-3.2-3b-preview
+  provider_model_id: groq/llama-3.2-3b-preview
   model_type: llm
 - metadata:
     embedding_dimension: 384
```

```diff
@@ -117,7 +117,9 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
 
-    model_ids = [m.identifier for m in client.models.list()]
+    model_ids = set(m.identifier for m in client.models.list())
+    model_ids.update(m.provider_resource_id for m in client.models.list())
 
     if text_model_id and text_model_id not in model_ids:
         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     if vision_model_id and vision_model_id not in model_ids:
```

```diff
@@ -18,7 +18,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
+    if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
```