forked from phoenix-oss/llama-stack-mirror
Litellm support in llama stack:
parent 15dcc4ea5e
commit 9e0c8a82cb
6 changed files with 359 additions and 0 deletions
llama_stack/providers/remote/inference/litellm/litellm.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

from typing import AsyncIterator, List, Optional, Union, Any

from litellm import completion as litellm_completion
from litellm.types.utils import ModelResponse

from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, StopReason
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionMessage,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)

# from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

_MODEL_ALIASES = [
    build_model_alias(
        "gpt-4o",  # provider_model_id
        "gpt-4o",  # model_descriptor
    ),
]


class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
    _config: LitellmConfig

    def __init__(self, config: LitellmConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # litellm doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    def _messages_to_litellm_messages(
        self,
        messages: List[Message],
    ) -> list[dict[str, Any]]:
        litellm_messages = []
        for message in messages:
            lm_message = {
                "role": message.role,
                "content": message.content,
            }
            litellm_messages.append(lm_message)
        return litellm_messages

    def _convert_to_llama_stack_response(
        self,
        litellm_response: ModelResponse,
    ) -> ChatCompletionResponse:
        assert litellm_response.choices is not None
        assert len(litellm_response.choices) == 1
        message = litellm_response.choices[0].message
        completion_message = CompletionMessage(
            content=message["content"],
            role=message["role"],
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        )
        return ChatCompletionResponse(completion_message=completion_message)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        assert stream is False, "streaming not supported"
        model_id = self.get_provider_model_id(model_id)
        response = litellm_completion(
            model=model_id,
            custom_llm_provider=self._config.llm_provider,
            messages=self._messages_to_litellm_messages(messages),
            api_key=self._config.openai_api_key,
        )
        return self._convert_to_llama_stack_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
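The adapter imports LitellmConfig from llama_stack.providers.remote.inference.litellm.config, which is presumably one of the other changed files in this commit but is not shown in this view. Based only on the two attributes the adapter reads (llm_provider and openai_api_key), a minimal sketch of what that config might look like, assuming the usual pydantic-based provider configs in llama-stack, is:

# Hypothetical sketch of llama_stack/providers/remote/inference/litellm/config.py.
# Not part of the diff shown above; only the llm_provider and openai_api_key
# fields are implied by the adapter code.
from typing import Optional

from pydantic import BaseModel, Field


class LitellmConfig(BaseModel):
    # Forwarded to litellm's completion() as custom_llm_provider (e.g. "openai")
    llm_provider: str = Field(default="openai")
    # API key forwarded to litellm's completion() call
    openai_api_key: Optional[str] = Field(default=None)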
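For context on _messages_to_litellm_messages: litellm's completion() accepts OpenAI-style message dictionaries with "role" and "content" keys, which is exactly the shape the helper produces. A minimal standalone call, assuming an OpenAI key is set in the environment, looks like this:

# Standalone litellm call using the same message shape the adapter builds.
# Assumes OPENAI_API_KEY is set in the environment.
from litellm import completion

response = completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)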
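Putting the pieces together, a hedged sketch of driving the adapter directly (outside a full llama-stack distribution) could look like the following. The import path for UserMessage and the config construction are assumptions for illustration, not part of this diff:

# Hypothetical driver for LitellmInferenceAdapter; import paths and config
# values are assumptions based on the APIs referenced in the diff above.
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.remote.inference.litellm.litellm import (
    LitellmInferenceAdapter,
)


async def main() -> None:
    config = LitellmConfig(llm_provider="openai", openai_api_key="sk-...")
    adapter = LitellmInferenceAdapter(config)

    # "gpt-4o" resolves to itself through the _MODEL_ALIASES registry.
    response = await adapter.chat_completion(
        model_id="gpt-4o",
        messages=[UserMessage(content="What is the capital of France?")],
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(main())

Note that chat_completion is declared async but calls litellm's synchronous completion() under the hood, so the call blocks the event loop for the duration of the request; streaming and the non-chat completion and embeddings endpoints are not implemented in this commit.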