# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, AsyncIterator, List, Optional, Union

from litellm import completion as litellm_completion
from litellm.types.utils import ModelResponse

from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import StopReason, ToolDefinition, ToolPromptFormat

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)

# from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_alias,
)

# Static aliases mapping llama-stack model descriptors to provider model ids
# (identical for OpenAI's gpt-4o).
_MODEL_ALIASES = [
    build_model_alias(
        "gpt-4o",  # provider_model_id
        "gpt-4o",  # model_descriptor
    ),
]

class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
    """Inference adapter that forwards chat completion requests to LiteLLM."""

    _config: LitellmConfig

    def __init__(self, config: LitellmConfig):
        # Register the static aliases so model_id lookups resolve to provider model ids.
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # litellm doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    def _messages_to_litellm_messages(
        self,
        messages: List[Message],
    ) -> list[dict[str, Any]]:
        # Convert llama-stack Message objects into the plain role/content dicts
        # that litellm expects.
        litellm_messages = []
        for message in messages:
            lm_message = {
                "role": message.role,
                "content": message.content,
            }
            litellm_messages.append(lm_message)
        return litellm_messages

    def _convert_to_llama_stack_response(
        self,
        litellm_response: ModelResponse,
    ) -> ChatCompletionResponse:
        # Only single-choice, non-streaming responses are supported.
        assert litellm_response.choices is not None
        assert len(litellm_response.choices) == 1
        message = litellm_response.choices[0].message
        completion_message = CompletionMessage(
            content=message["content"],
            role=message["role"],
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        )
        return ChatCompletionResponse(completion_message=completion_message)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        assert stream is False, "streaming not supported"
        # Resolve the llama-stack model_id to the provider's model id before calling litellm.
        model_id = self.get_provider_model_id(model_id)
        # Note: litellm's synchronous `completion` is used here, so this call blocks
        # the event loop for the duration of the request.
        response = litellm_completion(
            model=model_id,
            custom_llm_provider=self._config.llm_provider,
            messages=self._messages_to_litellm_messages(messages),
            api_key=self._config.openai_api_key,
        )

        return self._convert_to_llama_stack_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
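

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how this adapter might be exercised directly.
# It assumes `UserMessage` is importable from llama_stack.apis.inference, that
# LitellmConfig accepts `llm_provider` and `openai_api_key` keyword arguments
# (the two fields referenced above), and that OPENAI_API_KEY is set in the
# environment.
if __name__ == "__main__":
    import asyncio

    from llama_stack.apis.inference import UserMessage

    async def _demo() -> None:
        # Hypothetical config values for illustration only.
        adapter = LitellmInferenceAdapter(
            LitellmConfig(
                llm_provider="openai",
                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
            )
        )
        response = await adapter.chat_completion(
            model_id="gpt-4o",
            messages=[UserMessage(content="Say hello in one short sentence.")],
        )
        print(response.completion_message.content)

    asyncio.run(_demo())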