mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 01:53:10 +00:00
feat(providers): Groq now uses LiteLLM openai-compat (#1303)
Groq has never supported raw completions anyhow. So this makes it easier to switch it to LiteLLM. All our test suite passes. I also updated all the openai-compat providers so they work with api keys passed from headers. `provider_data` ## Test Plan ```bash LLAMA_STACK_CONFIG=groq \ pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model="" ``` Also tested (openai, anthropic, gemini) providers. No regressions.
This commit is contained in:
parent
564f0e5f93
commit
928a39d17b
23 changed files with 165 additions and 1004 deletions
|
@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
|
|||
|
||||
class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
await super().shutdown()
|
||||
|
|
|
@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
|
@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
|
@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
|
|||
|
||||
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
await super().shutdown()
|
||||
|
|
|
@ -4,23 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
if not isinstance(config, GroqConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
adapter = GroqInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
|
@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
|
@ -25,8 +32,8 @@ class GroqConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": "${env.GROQ_API_KEY}",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
@ -4,130 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import groq
|
||||
from groq import Groq
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
convert_chat_completion_response_stream,
|
||||
)
|
||||
from .models import _MODEL_ENTRIES
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
|
||||
class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqConfig
|
||||
|
||||
def __init__(self, config: GroqConfig):
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
# Groq doesn't support non-chat completion as of time of writing
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
model_id = self.get_provider_model_id(model_id)
|
||||
if model_id == "llama-3.2-3b-preview":
|
||||
warnings.warn(
|
||||
"Groq only contains a preview version for llama-3.2-3b-instruct. "
|
||||
"Preview models aren't recommended for production use. "
|
||||
"They can be discontinued on short notice."
|
||||
"More details: https://console.groq.com/docs/models"
|
||||
)
|
||||
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
try:
|
||||
response = self._get_client().chat.completions.create(**request)
|
||||
except groq.BadRequestError as e:
|
||||
if e.body.get("error", {}).get("code") == "tool_use_failed":
|
||||
# For smaller models, Groq may fail to call a tool even when the request is well formed
|
||||
raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
|
||||
else:
|
||||
raise e
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
if stream:
|
||||
return convert_chat_completion_response_stream(response)
|
||||
else:
|
||||
return convert_chat_completion_response(response)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_client(self) -> Groq:
|
||||
if self._config.api_key is not None:
|
||||
return Groq(api_key=self._config.api_key)
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.groq_api_key:
|
||||
raise ValueError(
|
||||
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
|
||||
)
|
||||
return Groq(api_key=provider_data.groq_api_key)
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
|
@ -1,245 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from groq import Stream
|
||||
from groq.types.chat.chat_completion import ChatCompletion
|
||||
from groq.types.chat.chat_completion_assistant_message_param import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||
from groq.types.chat.chat_completion_system_message_param import (
|
||||
ChatCompletionSystemMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
||||
from groq.types.chat.chat_completion_user_message_param import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
UnparseableToolCall,
|
||||
convert_tool_call,
|
||||
get_sampling_strategy_options,
|
||||
)
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
) -> CompletionCreateParams:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
|
||||
Warns client if request contains unsupported features.
|
||||
"""
|
||||
|
||||
if request.logprobs:
|
||||
# Groq doesn't support logprobs at the time of writing
|
||||
warnings.warn("logprobs are not supported yet")
|
||||
|
||||
if request.response_format:
|
||||
# Groq's JSON mode is beta at the time of writing
|
||||
warnings.warn("response_format is not supported yet")
|
||||
|
||||
if request.sampling_params.repetition_penalty != 1.0:
|
||||
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
|
||||
# seem to have different semantics
|
||||
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
|
||||
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
|
||||
# so we exclude it for now
|
||||
warnings.warn("repetition_penalty is not supported")
|
||||
|
||||
if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
|
||||
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||
|
||||
sampling_options = get_sampling_strategy_options(request.sampling_params)
|
||||
return CompletionCreateParams(
|
||||
model=request.model,
|
||||
messages=[_convert_message(message) for message in request.messages],
|
||||
logprobs=None,
|
||||
frequency_penalty=None,
|
||||
stream=request.stream,
|
||||
max_tokens=request.sampling_params.max_tokens or None,
|
||||
temperature=sampling_options.get("temperature", 1.0),
|
||||
top_p=sampling_options.get("top_p", 1.0),
|
||||
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||
tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
|
||||
)
|
||||
|
||||
|
||||
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||
if message.role == "system":
|
||||
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
||||
elif message.role == "user":
|
||||
return ChatCompletionUserMessageParam(role="user", content=message.content)
|
||||
elif message.role == "assistant":
|
||||
return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
|
||||
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||
# Groq requires a description for function tools
|
||||
if tool_definition.description is None:
|
||||
raise AssertionError("tool_definition.description is required")
|
||||
|
||||
tool_parameters = tool_definition.parameters or {}
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_definition.tool_name,
|
||||
description=tool_definition.description,
|
||||
parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
|
||||
param = {
|
||||
"type": tool_parameter.param_type,
|
||||
}
|
||||
if tool_parameter.description is not None:
|
||||
param["description"] = tool_parameter.description
|
||||
if tool_parameter.required is not None:
|
||||
param["required"] = tool_parameter.required
|
||||
if tool_parameter.default is not None:
|
||||
param["default"] = tool_parameter.default
|
||||
return param
|
||||
|
||||
|
||||
def convert_chat_completion_response(
|
||||
response: ChatCompletion,
|
||||
) -> ChatCompletionResponse:
|
||||
# groq only supports n=1 at time of writing, so there is only one choice
|
||||
choice = response.choices[0]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
|
||||
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=StopReason.end_of_message,
|
||||
# Content is not optional
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _map_finish_reason_to_stop_reason(
|
||||
finish_reason: Literal["stop", "length", "tool_calls"],
|
||||
) -> StopReason:
|
||||
"""
|
||||
Convert a Groq chat completion finish_reason to a StopReason.
|
||||
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
- stop -> model hit a natural stop point or a provided stop sequence
|
||||
- length -> maximum number of tokens specified in the request was reached
|
||||
- tool_calls -> model called a tool
|
||||
"""
|
||||
if finish_reason == "stop":
|
||||
return StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
return StopReason.out_of_tokens
|
||||
elif finish_reason == "tool_calls":
|
||||
return StopReason.end_of_message
|
||||
else:
|
||||
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
||||
|
||||
|
||||
async def convert_chat_completion_response_stream(
|
||||
stream: Stream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
event_type = ChatCompletionResponseEventType.start
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
|
||||
if choice.finish_reason:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
)
|
||||
)
|
||||
elif choice.delta.tool_calls:
|
||||
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
||||
if len(choice.delta.tool_calls) > 1:
|
||||
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
|
||||
|
||||
# We assume Groq produces fully formed tool calls for each chunk
|
||||
tool_call = convert_tool_call(choice.delta.tool_calls[0])
|
||||
if isinstance(tool_call, ToolCall):
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Otherwise it's an UnparseableToolCall - return the raw tool call
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call.model_dump_json(),
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
event_type = ChatCompletionResponseEventType.progress
|
|
@ -7,21 +7,21 @@
|
|||
from llama_stack.models.llama.sku_list import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import build_model_entry
|
||||
|
||||
_MODEL_ENTRIES = [
|
||||
MODEL_ENTRIES = [
|
||||
build_model_entry(
|
||||
"llama3-8b-8192",
|
||||
"groq/llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama-3.1-8b-instant",
|
||||
"groq/llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3-70b-8192",
|
||||
"groq/llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama-3.3-70b-versatile",
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
|
@ -29,7 +29,7 @@ _MODEL_ENTRIES = [
|
|||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_model_entry(
|
||||
"llama-3.2-3b-preview",
|
||||
"groq/llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
|
@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
|
@ -12,11 +12,16 @@ from .models import MODEL_ENTRIES
|
|||
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
await super().shutdown()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue