# llama-stack/llama_stack/providers/remote/inference/litellm/litellm.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, List, Optional, Union

# litellm's async entrypoint, so chat_completion does not block the event loop.
from litellm import acompletion as litellm_acompletion
from litellm.types.utils import ModelResponse
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, StopReason
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionMessage,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
ToolChoice,
)
# from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
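
# Aliases map the llama-stack model descriptor to the model id that litellm /
# the underlying provider expects; only gpt-4o is registered here.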
_MODEL_ALIASES = [
build_model_alias(
"gpt-4o", # provider_model_id
"gpt-4o", # model_descriptor
),
]


class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
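    """Remote inference adapter that forwards chat completion requests to an
    OpenAI-compatible provider via litellm."""
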
_config: LitellmConfig
def __init__(self, config: LitellmConfig):
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
self._config = config
def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
# litellm doesn't support non-chat completion as of time of writing
raise NotImplementedError()
def _messages_to_litellm_messages(
self,
messages: List[Message],
) -> list[dict[str, Any]]:
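        """Convert llama-stack messages into the plain role/content dicts litellm expects."""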
litellm_messages = []
for message in messages:
lm_message = {
"role": message.role,
"content": message.content,
}
litellm_messages.append(lm_message)
return litellm_messages
def _convert_to_llama_stack_response(
self,
litellm_response: ModelResponse,
) -> ChatCompletionResponse:
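        """Map a litellm ModelResponse onto a llama-stack ChatCompletionResponse.

        Only single-choice, non-streaming responses are handled.
        """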
assert litellm_response.choices is not None
assert len(litellm_response.choices) == 1
        message = litellm_response.choices[0].message
        # Note: litellm's finish_reason is not mapped; end_of_message is assumed,
        # and provider tool calls are not converted yet.
        completion_message = CompletionMessage(
            content=message["content"],
            role=message["role"],
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        )
        return ChatCompletionResponse(completion_message=completion_message)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
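        """Run a non-streaming chat completion through litellm.

        sampling_params, response_format, tools, tool_choice, tool_prompt_format,
        and logprobs are accepted for interface compatibility but are not yet
        forwarded to litellm.
        """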
assert stream is False, "streaming not supported"
model_id = self.get_provider_model_id(model_id)
        response = await litellm_acompletion(
            model=model_id,
            custom_llm_provider=self._config.llm_provider,
            messages=self._messages_to_litellm_messages(messages),
            api_key=self._config.openai_api_key,
        )
return self._convert_to_llama_stack_response(response)
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()
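

# Minimal usage sketch (illustrative only, not part of the adapter). It assumes
# LitellmConfig exposes the `llm_provider` and `openai_api_key` fields referenced
# above, that `UserMessage` is importable from llama_stack.apis.inference, and
# that the registered "gpt-4o" alias resolves as-is; adjust to the real wiring.
if __name__ == "__main__":
    import asyncio
    import os

    from llama_stack.apis.inference import UserMessage

    async def _demo() -> None:
        config = LitellmConfig(
            llm_provider="openai",  # provider name routed through litellm (assumption)
            openai_api_key=os.environ["OPENAI_API_KEY"],
        )
        adapter = LitellmInferenceAdapter(config)
        response = await adapter.chat_completion(
            model_id="gpt-4o",
            messages=[UserMessage(content="Say hello in one sentence.")],
        )
        print(response.completion_message.content)

    asyncio.run(_demo())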