forked from phoenix-oss/llama-stack-mirror
		
	Note: the OpenAI provider exposes LiteLLM-specific model names to the user. This change remains compatible with that behavior, but the LiteLLM names should be deprecated.
		
			
				
	
	
		
			391 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			391 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from collections.abc import AsyncGenerator, AsyncIterator
 | |
| from typing import Any
 | |
| 
 | |
| import litellm
 | |
| 
 | |
| from llama_stack.apis.common.content_types import (
 | |
|     InterleavedContent,
 | |
|     InterleavedContentItem,
 | |
| )
 | |
| from llama_stack.apis.inference import (
 | |
|     ChatCompletionRequest,
 | |
|     ChatCompletionResponse,
 | |
|     ChatCompletionResponseStreamChunk,
 | |
|     EmbeddingsResponse,
 | |
|     EmbeddingTaskType,
 | |
|     Inference,
 | |
|     JsonSchemaResponseFormat,
 | |
|     LogProbConfig,
 | |
|     Message,
 | |
|     ResponseFormat,
 | |
|     SamplingParams,
 | |
|     TextTruncation,
 | |
|     ToolChoice,
 | |
|     ToolConfig,
 | |
|     ToolDefinition,
 | |
|     ToolPromptFormat,
 | |
| )
 | |
| from llama_stack.apis.inference.inference import (
 | |
|     OpenAIChatCompletion,
 | |
|     OpenAIChatCompletionChunk,
 | |
|     OpenAICompletion,
 | |
|     OpenAIMessageParam,
 | |
|     OpenAIResponseFormatParam,
 | |
| )
 | |
| from llama_stack.apis.models.models import Model
 | |
| from llama_stack.distribution.request_headers import NeedsRequestProviderData
 | |
| from llama_stack.log import get_logger
 | |
| from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 | |
| from llama_stack.providers.utils.inference.openai_compat import (
 | |
|     convert_message_to_openai_dict_new,
 | |
|     convert_openai_chat_completion_choice,
 | |
|     convert_openai_chat_completion_stream,
 | |
|     convert_tooldef_to_openai_tool,
 | |
|     get_sampling_options,
 | |
|     prepare_openai_completion_params,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     interleaved_content_as_str,
 | |
| )
 | |
| 
 | |
# Module-level logger for this file, registered under the "inference" logging category.
logger = get_logger(category="inference", name=__name__)
 | |
| 
 | |
| 
 | |
class LiteLLMOpenAIMixin(
    ModelRegistryHelper,
    Inference,
    NeedsRequestProviderData,
):
    """Inference provider mixin that forwards requests to ``litellm``.

    Supported entry points:
    - ``chat_completion``
    - ``embeddings``
    - the OpenAI-compatible ``openai_completion`` and ``openai_chat_completion``

    The legacy ``completion`` endpoint and the batch endpoints raise
    ``NotImplementedError``.

    When ``openai_compat_api_base`` is given, the provider is treated as a generic
    OpenAI-compatible server. Requests are sent to that base URL, and model names
    get an ``openai/`` prefix (LiteLLM's provider prefix for OpenAI-compatible
    endpoints).
    """

    # TODO: avoid exposing the litellm specific model names to the user.
    #       potential change: add a prefix param that gets added to the model name
    #                         when calling litellm.
    def __init__(
        self,
        model_entries,
        api_key_from_config: str | None,
        provider_data_api_key_field: str,
        openai_compat_api_base: str | None = None,
    ):
        """
        :param model_entries: model entries handed to ``ModelRegistryHelper``.
        :param api_key_from_config: API key used when the current request carries no
            provider-data key.
        :param provider_data_api_key_field: attribute name read from the per-request
            provider data to find a caller-supplied API key.
        :param openai_compat_api_base: base URL of an OpenAI-compatible server. When
            set, model names are sent to LiteLLM with an ``openai/`` prefix.
        """
        ModelRegistryHelper.__init__(self, model_entries)
        self.api_key_from_config = api_key_from_config
        self.provider_data_api_key_field = provider_data_api_key_field
        self.api_base = openai_compat_api_base
        self.is_openai_compat = bool(openai_compat_api_base)

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    async def register_model(self, model: Model) -> Model:
        """Accept ``model`` if the registry helper can resolve its provider resource id.

        :raises ValueError: if ``model.provider_resource_id`` is not a known model.
        """
        model_id = self.get_provider_model_id(model.provider_resource_id)
        if model_id is None:
            raise ValueError(f"Unsupported model: {model.provider_resource_id}")
        return model

    def get_litellm_model_name(self, model_id: str) -> str:
        """Return the model name to pass to LiteLLM for ``model_id``."""
        # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
        # model_id.startswith("openai/") is for backwards compatibility.
        return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> AsyncGenerator:
        raise NotImplementedError("LiteLLM does not support completion requests")

    async def chat_completion(
        self,
        model_id: str,
        messages: list[Message],
        sampling_params: SamplingParams | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
        """Run a chat completion through LiteLLM.

        Returns a ``ChatCompletionResponse`` or, when ``stream`` is true, an async
        iterator of ``ChatCompletionResponseStreamChunk``.

        ``tool_config`` takes precedence. If it is not given, the legacy ``tool_choice``
        and ``tool_prompt_format`` arguments are folded into a ``ToolConfig`` so they
        are honoured rather than silently dropped.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if tool_config is None:
            # _get_params reads request.tool_config when tools are present; without this,
            # a None config would crash there and tool_choice would be ignored.
            tool_config = ToolConfig(tool_choice=tool_choice, tool_prompt_format=tool_prompt_format)

        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        params = await self._get_params(request)
        params["model"] = self.get_litellm_model_name(params["model"])

        logger.debug(f"params to litellm (openai compat): {params}")
        # unfortunately, we need to use synchronous litellm.completion here because litellm
        # caches various httpx.client objects in a non-eventloop aware manner
        response = litellm.completion(**params)
        if stream:
            return self._stream_chat_completion(response)
        else:
            return convert_openai_chat_completion_choice(response.choices[0])

    async def _stream_chat_completion(
        self, response: litellm.ModelResponse
    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
        """Adapt LiteLLM's synchronous chunk stream into llama-stack stream chunks."""

        async def _stream_generator():
            # The litellm response is iterated synchronously (see chat_completion);
            # wrap it so the async converter below can consume it.
            for chunk in response:
                yield chunk

        async for chunk in convert_openai_chat_completion_stream(
            _stream_generator(), enable_incremental_tool_calls=True
        ):
            yield chunk

    def _add_additional_properties_recursive(self, schema):
        """
        Recursively add additionalProperties: False to all object schemas.

        Every object schema that has properties also gets all of its property names
        listed in ``required``.

        Nested schemas are visited under:
        - ``properties`` and ``items``
        - ``anyOf``, ``allOf``, ``oneOf`` and ``not``
        - ``$defs`` and ``definitions``

        The schema is modified in place and also returned. Non-dict input is returned
        unchanged.
        """
        if not isinstance(schema, dict):
            return schema

        if schema.get("type") == "object":
            schema["additionalProperties"] = False

            # Add required field with all property keys if properties exist
            if schema.get("properties"):
                schema["required"] = list(schema["properties"].keys())

        for prop_schema in (schema.get("properties") or {}).values():
            self._add_additional_properties_recursive(prop_schema)

        # Array element schemas: a single schema, or a list of schemas (tuple form).
        items = schema.get("items")
        if isinstance(items, list):
            for item_schema in items:
                self._add_additional_properties_recursive(item_schema)
        else:
            self._add_additional_properties_recursive(items)

        for key in ("anyOf", "allOf", "oneOf"):
            for sub_schema in schema.get(key) or []:
                self._add_additional_properties_recursive(sub_schema)

        self._add_additional_properties_recursive(schema.get("not"))

        # Handle $defs/$ref (and the older "definitions" keyword)
        for defs_key in ("$defs", "definitions"):
            for def_schema in (schema.get(defs_key) or {}).values():
                self._add_additional_properties_recursive(def_schema)

        return schema

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for ``litellm.completion`` from ``request``.

        :raises ValueError: if the response format is not a ``JsonSchemaResponseFormat``,
            or if its JSON schema has no ``title`` (used as the schema name).
        """
        import copy  # only needed to avoid mutating the caller's response_format

        input_dict = {}

        input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
        if fmt := request.response_format:
            if not isinstance(fmt, JsonSchemaResponseFormat):
                raise ValueError(
                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                )

            # Work on a deep copy: the schema is edited in place below. Mutating the caller's
            # object would make a reused response_format fail on its second use, because
            # "title" would already have been removed.
            schema = copy.deepcopy(fmt.json_schema)
            if "title" not in schema:
                raise ValueError("JsonSchemaResponseFormat.json_schema must contain a 'title' to use as its name.")
            name = schema.pop("title")
            schema["additionalProperties"] = False

            # Apply additionalProperties: False recursively to all objects
            schema = self._add_additional_properties_recursive(schema)

            input_dict["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": name,
                    "schema": schema,
                    "strict": True,
                },
            }
        if request.tools:
            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
            tool_choice = request.tool_config.tool_choice if request.tool_config else None
            if tool_choice:
                input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice

        return {
            "model": request.model,
            "api_key": self.get_api_key(),
            "api_base": self.api_base,
            **input_dict,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    def get_api_key(self) -> str | None:
        """Return the API key for the current request.

        A truthy key in the request's provider data (under
        ``provider_data_api_key_field``) wins. Otherwise the configured key is used,
        which may be None.
        """
        provider_data = self.get_request_provider_data()
        request_key = getattr(provider_data, self.provider_data_api_key_field, None) if provider_data else None
        return request_key or self.api_key_from_config

    async def embeddings(
        self,
        model_id: str,
        contents: list[str] | list[InterleavedContentItem],
        text_truncation: TextTruncation | None = TextTruncation.none,
        output_dimension: int | None = None,
        task_type: EmbeddingTaskType | None = None,
    ) -> EmbeddingsResponse:
        """Embed ``contents`` with the model registered as ``model_id``.

        Each item is flattened to a string before it is sent. ``output_dimension`` is
        forwarded to LiteLLM as ``dimensions``. ``text_truncation`` and ``task_type``
        are accepted for interface compatibility but are not forwarded.
        """
        model = await self.model_store.get_model(model_id)

        # Send the resolved API key and base URL, as every other entry point does.
        # Without them, a per-request provider-data key or an OpenAI-compatible
        # api_base would be ignored, and LiteLLM would fall back to its own defaults.
        response = litellm.embedding(
            model=self.get_litellm_model_name(model.provider_resource_id),
            input=[interleaved_content_as_str(content) for content in contents],
            api_key=self.get_api_key(),
            api_base=self.api_base,
            dimensions=output_dimension,
        )

        embeddings = [data["embedding"] for data in response["data"]]
        return EmbeddingsResponse(embeddings=embeddings)

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
    ) -> OpenAICompletion:
        """Forward an OpenAI-style text completion request to ``litellm.atext_completion``."""
        model_obj = await self.model_store.get_model(model)
        params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
            guided_choice=guided_choice,
            prompt_logprobs=prompt_logprobs,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.atext_completion(**params)

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Forward an OpenAI-style chat completion request to ``litellm.acompletion``."""
        model_obj = await self.model_store.get_model(model)
        params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.acompletion(**params)

    async def batch_completion(
        self,
        model_id: str,
        content_batch: list[InterleavedContent],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch completion is not supported for OpenAI Compat")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: list[list[Message]],
        sampling_params: SamplingParams | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_config: ToolConfig | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
 |