diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index df7610935..3e43a844c 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
+from openai import AsyncOpenAI
 from together import AsyncTogether
 
 from llama_stack.apis.common.content_types import (
@@ -30,12 +31,14 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
+    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -60,6 +63,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
         self._client = None
+        self._openai_client = None
 
     async def initialize(self) -> None:
         pass
@@ -110,6 +114,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             self._client = AsyncTogether(api_key=together_api_key)
         return self._client
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        if not self._openai_client:
+            together_client = self._get_client().client
+            self._openai_client = AsyncOpenAI(
+                base_url=together_client.base_url,
+                api_key=together_client.api_key,
+            )
+        return self._openai_client
+
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         client = self._get_client()
@@ -243,3 +256,99 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         )
         embeddings = [item.embedding for item in r.data]
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: str,
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self._get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().completions.create(**params)  # type: ignore
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAIChatCompletion:
+        model_obj = await self._get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d7555c39f..66fb986f9 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -59,6 +59,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
+    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
@@ -441,29 +442,25 @@
         user: Optional[str] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
-        params = {
-            k: v
-            for k, v in {
-                "model": model_obj.provider_resource_id,
-                "prompt": prompt,
-                "best_of": best_of,
-                "echo": echo,
-                "frequency_penalty": frequency_penalty,
-                "logit_bias": logit_bias,
-                "logprobs": logprobs,
-                "max_tokens": max_tokens,
-                "n": n,
-                "presence_penalty": presence_penalty,
-                "seed": seed,
-                "stop": stop,
-                "stream": stream,
-                "stream_options": stream_options,
-                "temperature": temperature,
-                "top_p": top_p,
-                "user": user,
-            }.items()
-            if v is not None
-        }
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
         return await self.client.completions.create(**params)  # type: ignore
 
     async def openai_chat_completion(
@@ -493,33 +490,29 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         user: Optional[str] = None,
     ) -> OpenAIChatCompletion:
         model_obj = await self._get_model(model)
-        params = {
-            k: v
-            for k, v in {
-                "model": model_obj.provider_resource_id,
-                "messages": messages,
-                "frequency_penalty": frequency_penalty,
-                "function_call": function_call,
-                "functions": functions,
-                "logit_bias": logit_bias,
-                "logprobs": logprobs,
-                "max_completion_tokens": max_completion_tokens,
-                "max_tokens": max_tokens,
-                "n": n,
-                "parallel_tool_calls": parallel_tool_calls,
-                "presence_penalty": presence_penalty,
-                "response_format": response_format,
-                "seed": seed,
-                "stop": stop,
-                "stream": stream,
-                "stream_options": stream_options,
-                "temperature": temperature,
-                "tool_choice": tool_choice,
-                "tools": tools,
-                "top_logprobs": top_logprobs,
-                "top_p": top_p,
-                "user": user,
-            }.items()
-            if v is not None
-        }
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
         return await self.client.chat.completions.create(**params)  # type: ignore
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index bd1eb3978..8111e4463 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 import litellm
 
@@ -30,6 +30,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
@@ -40,6 +41,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
     convert_openai_chat_completion_stream,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
+    prepare_openai_completion_params,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -245,3 +247,99 @@ class LiteLLMOpenAIMixin(
 
         embeddings = [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: str,
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self._get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return litellm.text_completion(**params)
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAIChatCompletion:
+        model_obj = await self._get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        return litellm.completion(**params)
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index d9091d5c8..ea02de573 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1052,6 +1052,11 @@ async def convert_openai_chat_completion_stream(
     )
 
 
+async def prepare_openai_completion_params(**params):
+    completion_params = {k: v for k, v in params.items() if v is not None}
+    return completion_params
+
+
 class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
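
The new `prepare_openai_completion_params` helper simply drops kwargs whose value is `None`, which is the same behavior as the dict comprehension it replaces in vllm.py. A minimal sketch of that behavior (not part of the diff; the model id and argument values are illustrative only):

```python
# Sketch: prepare_openai_completion_params filters out unset (None) arguments so the
# underlying OpenAI-compatible client never receives explicit nulls.
import asyncio

from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params


async def main() -> None:
    params = await prepare_openai_completion_params(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo",  # hypothetical provider_resource_id
        prompt="Hello",
        temperature=0.7,
        max_tokens=None,  # unset arguments...
        stream=None,      # ...are filtered out
    )
    assert params == {
        "model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "prompt": "Hello",
        "temperature": 0.7,
    }


asyncio.run(main())
```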
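For the Together adapter, `_get_openai_client()` lazily wraps the existing Together credentials in an `AsyncOpenAI` client so the new `openai_completion` / `openai_chat_completion` methods can delegate to the standard OpenAI SDK. A standalone sketch of the same wiring, assuming Together's OpenAI-compatible endpoint, an environment variable for the key, and an example model id (none of these are taken from the diff):

```python
# Sketch of the client setup the adapter performs internally: an AsyncOpenAI client
# pointed at Together's OpenAI-compatible base URL (assumed values throughout).
import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://api.together.xyz/v1",   # assumed Together endpoint
        api_key=os.environ["TOGETHER_API_KEY"],   # assumed env var
    )
    response = await client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo",  # assumed model id
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```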