From a083465ba4ea4034e31da35142ab983058e2cbd3 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Mon, 28 Apr 2025 09:21:23 -0400
Subject: [PATCH] Add openai completion/chat completion

---
 .../remote/inference/lmstudio/lmstudio.py     | 145 ++++++++++++++++--
 1 file changed, 136 insertions(+), 9 deletions(-)

diff --git a/llama_stack/providers/remote/inference/lmstudio/lmstudio.py b/llama_stack/providers/remote/inference/lmstudio/lmstudio.py
index 8f7377bc9..59ec68fe3 100644
--- a/llama_stack/providers/remote/inference/lmstudio/lmstudio.py
+++ b/llama_stack/providers/remote/inference/lmstudio/lmstudio.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncIterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -19,6 +19,11 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -51,17 +56,139 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
     def client(self) -> LMStudioClient:
         return LMStudioClient(url=self.url)
 
-    async def batch_chat_completion(self, *args, **kwargs):
-        raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported by LM Studio Provider")
 
-    async def batch_completion(self, *args, **kwargs):
-        raise NotImplementedError("Batch completion not supported by LM Studio Provider")
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported by LM Studio Provider")
 
-    async def openai_chat_completion(self, *args, **kwargs):
-        raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
+        model_obj = await self.model_store.get_model(model)
+        params = {
+            k: v
+            for k, v in {
+                "model": model_obj.provider_resource_id,
+                "messages": messages,
+                "frequency_penalty": frequency_penalty,
+                "function_call": function_call,
+                "functions": functions,
+                "logit_bias": logit_bias,
+                "logprobs": logprobs,
+                "max_completion_tokens": max_completion_tokens,
+                "max_tokens": max_tokens,
+                "n": n,
+                "parallel_tool_calls": parallel_tool_calls,
+                "presence_penalty": presence_penalty,
+                "response_format": response_format,
+                "seed": seed,
+                "stop": stop,
+                "stream": stream,
+                "stream_options": stream_options,
+                "temperature": temperature,
+                "tool_choice": tool_choice,
+                "tools": tools,
+                "top_logprobs": top_logprobs,
+                "top_p": top_p,
+                "user": user,
+            }.items()
+            if v is not None
+        }
+        return await self.openai_client.chat.completions.create(**params)  # type: ignore
 
-    async def openai_completion(self, *args, **kwargs):
-        raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        if not isinstance(prompt, str):
+            raise ValueError("LM Studio does not support non-string prompts for completion")
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
+        model_obj = await self.model_store.get_model(model)
+        params = {
+            k: v
+            for k, v in {
+                "model": model_obj.provider_resource_id,
+                "prompt": prompt,
+                "best_of": best_of,
+                "echo": echo,
+                "frequency_penalty": frequency_penalty,
+                "logit_bias": logit_bias,
+                "logprobs": logprobs,
+                "max_tokens": max_tokens,
+                "n": n,
+                "presence_penalty": presence_penalty,
+                "seed": seed,
+                "stop": stop,
+                "stream": stream,
+                "stream_options": stream_options,
+                "temperature": temperature,
+                "top_p": top_p,
+                "user": user,
+            }.items()
+            if v is not None
+        }
+        return await self.openai_client.completions.create(**params)  # type: ignore
 
     async def initialize(self) -> None:
         pass
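
A minimal usage sketch of the two new OpenAI-compatible methods, assuming an already initialized LMStudioInferenceAdapter instance (as it would be constructed and wired up by the Llama Stack provider machinery) and a registered model; the adapter variable and the model identifier "my-lmstudio-model" below are hypothetical placeholders, not part of the patch.

# Usage sketch only: `adapter` is assumed to be an initialized
# LMStudioInferenceAdapter and "my-lmstudio-model" a registered model id.
import asyncio


async def demo(adapter) -> None:
    # Non-streaming OpenAI-compatible chat completion; plain dict messages
    # are forwarded as-is to the underlying OpenAI client.
    chat = await adapter.openai_chat_completion(
        model="my-lmstudio-model",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        temperature=0.2,
        max_tokens=64,
        stream=False,
    )
    print(chat)

    # Plain text completion; note the patch rejects non-string prompts.
    completion = await adapter.openai_completion(
        model="my-lmstudio-model",
        prompt="The capital of France is",
        max_tokens=8,
    )
    print(completion)


# asyncio.run(demo(adapter)) once an adapter instance is available.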