import asyncio
import json
import logging
import re
from typing import AsyncIterator, List, Literal, Optional, Union

import lmstudio as lms
from openai import AsyncOpenAI as OpenAI

from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
from llama_stack.apis.inference.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    Message,
    ToolConfig,
    ToolDefinition,
)
from llama_stack.models.llama.datatypes import (
    GreedySamplingStrategy,
    SamplingParams,
    StopReason,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
    convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    content_has_media,
    interleaved_content_as_str,
)

LlmPredictionStopReason = Literal[
    "userStopped",
    "modelUnloaded",
    "failed",
    "eosFound",
    "stopStringFound",
    "toolCalls",
    "maxPredictedTokensReached",
    "contextLengthReached",
]


class LMStudioClient:
    def __init__(self, url: str) -> None:
        self.url = url
        self.sdk_client = lms.Client(self.url)
        self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")

    # Standard error handling helper methods
    def _log_error(self, error, context=""):
        """Centralized error logging method"""
        logging.warning(f"Error in LMStudio {context}: {error}")

    async def _create_fallback_chat_stream(
        self, error_message="I encountered an error processing your request."
    ):
        """Create a standardized fallback stream for chat completions"""
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
            )
        )
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=TextDelta(text=error_message),
            )
        )
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta=TextDelta(text=""),
            )
        )

    async def _create_fallback_completion_stream(self, error_message="Error processing response"):
        """Create a standardized fallback stream for text completions"""
        yield CompletionResponseStreamChunk(
            delta=error_message,
        )

    def _create_fallback_chat_response(
        self, error_message="I encountered an error processing your request."
    ):
        """Create a standardized fallback response for chat completions"""
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=error_message,
                stop_reason=StopReason.end_of_message,
                tool_calls=None,
            )
        )

    def _create_fallback_completion_response(self, error_message="Error processing response"):
        """Create a standardized fallback response for text completions"""
        return CompletionResponse(
            content=error_message,
            stop_reason=StopReason.end_of_message,
        )

    def _handle_json_extraction(self, content, context="JSON extraction"):
        """Standardized method to extract valid JSON from potentially malformed content"""
        try:
            json_content = json.loads(content)
            return json.dumps(json_content)  # Re-serialize to ensure valid JSON
        except json.JSONDecodeError as e:
            self._log_error(e, f"{context} - Attempting to extract valid JSON")
            json_patterns = [
                r'(\{.*\})',  # Match anything between curly braces
                r'(\[.*\])',  # Match anything between square brackets
                r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
                r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
            ]
            for pattern in json_patterns:
                json_match = re.search(pattern, content, re.DOTALL)
                if json_match:
                    valid_json = json_match.group(1)
                    try:
                        json_content = json.loads(valid_json)
                        return json.dumps(json_content)  # Re-serialize to ensure valid JSON
                    except json.JSONDecodeError:
                        continue  # Try the next pattern
            # If we couldn't extract valid JSON, log a warning
            self._log_error("Failed to extract valid JSON", context)
            return None

    async def check_if_model_present_in_lmstudio(self, provider_model_id):
        models = await asyncio.to_thread(self.sdk_client.list_downloaded_models)
        model_ids = [m.model_key for m in models]
        if provider_model_id in model_ids:
            return True

        model_ids = [model_id.split("/")[-1] for model_id in model_ids]
        if provider_model_id in model_ids:
            return True
        return False

    async def get_embedding_model(self, provider_model_id: str):
        model = await asyncio.to_thread(
            self.sdk_client.embedding.model, provider_model_id
        )
        return model

    async def embed(
        self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]
    ):
        embeddings = await asyncio.to_thread(embedding_model.embed, contents)
        return embeddings

    async def get_llm(self, provider_model_id: str) -> lms.LLM:
        model = await asyncio.to_thread(self.sdk_client.llm.model, provider_model_id)
        return model

    async def _llm_respond_non_tools(
        self,
        llm: lms.LLM,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        json_schema: Optional[JsonSchemaResponseFormat] = None,
        stream: Optional[bool] = False,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        chat = self._convert_message_list_to_lmstudio_chat(messages)
        config = self._get_completion_config_from_params(sampling_params)
        if stream:

            async def stream_generator():
                prediction_stream = await asyncio.to_thread(
                    llm.respond_stream,
                    history=chat,
                    config=config,
                    response_format=json_schema,
                )
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.start,
                        delta=TextDelta(text=""),
                    )
                )
                # Convert to list to avoid StopIteration issues
                try:
                    chunks = await asyncio.to_thread(list, prediction_stream)
                except StopIteration:
                    # Handle StopIteration by returning an empty list
                    chunks = []
                except Exception as e:
                    self._log_error(e, "converting chat stream to list")
                    chunks = []
                # Yield each chunk
                for chunk in chunks:
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=TextDelta(text=chunk.content),
                        )
                    )
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=TextDelta(text=""),
                    )
                )

            return stream_generator()
        else:
            response = await asyncio.to_thread(
                llm.respond,
                history=chat,
                config=config,
                response_format=json_schema,
            )
            return self._convert_prediction_to_chat_response(response)

    async def _llm_respond_with_tools(
        self,
        llm: lms.LLM,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        json_schema: Optional[JsonSchemaResponseFormat] = None,
        stream: Optional[bool] = False,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        try:
            model_key = llm.get_info().model_key
            request = ChatCompletionRequest(
                model=model_key,
                messages=messages,
                sampling_params=sampling_params,
                response_format=json_schema,
                tools=tools,
                tool_config=tool_config,
                stream=stream,
            )
            rest_request = await self._convert_request_to_rest_call(request)
            if stream:
                try:
                    stream = await self.openai_client.chat.completions.create(**rest_request)
                    return convert_openai_chat_completion_stream(
                        stream, enable_incremental_tool_calls=True
                    )
                except Exception as e:
                    self._log_error(e, "streaming tool calling")
                    return self._create_fallback_chat_stream()

            try:
                response = await self.openai_client.chat.completions.create(**rest_request)
                if response:
                    result = convert_openai_chat_completion_choice(response.choices[0])
                    return result
                else:
                    # Handle empty response
                    self._log_error("Empty response from OpenAI API", "chat completion")
                    return self._create_fallback_chat_response()
            except Exception as e:
                self._log_error(e, "non-streaming tool calling")
                return self._create_fallback_chat_response()
        except Exception as e:
            self._log_error(e, "_llm_respond_with_tools")
            # Return a fallback response
            if stream:
                return self._create_fallback_chat_stream()
            else:
                return self._create_fallback_chat_response()

    async def llm_respond(
        self,
        llm: lms.LLM,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        json_schema: Optional[JsonSchemaResponseFormat] = None,
        stream: Optional[bool] = False,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if tools is None or len(tools) == 0:
            return await self._llm_respond_non_tools(
                llm=llm,
                messages=messages,
                sampling_params=sampling_params,
                json_schema=json_schema,
                stream=stream,
            )
        else:
            return await self._llm_respond_with_tools(
                llm=llm,
                messages=messages,
                sampling_params=sampling_params,
                json_schema=json_schema,
                stream=stream,
                tools=tools,
                tool_config=tool_config,
            )

    async def llm_completion(
        self,
        llm: lms.LLM,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        json_schema: Optional[JsonSchemaResponseFormat] = None,
        stream: Optional[bool] = False,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        config = self._get_completion_config_from_params(sampling_params)
        if stream:

            async def stream_generator():
                try:
                    prediction_stream = await asyncio.to_thread(
                        llm.complete_stream,
                        prompt=interleaved_content_as_str(content),
                        config=config,
                        response_format=json_schema,
                    )
                    try:
                        chunks = await asyncio.to_thread(list, prediction_stream)
                    except StopIteration:
                        # Handle StopIteration by returning an empty list
                        chunks = []
                    except Exception as e:
                        self._log_error(e, "converting completion stream to list")
                        chunks = []
                    for chunk in chunks:
                        yield CompletionResponseStreamChunk(
                            delta=chunk.content,
                        )
                except Exception as e:
                    self._log_error(e, "streaming completion")
                    # Return a fallback response in case of error
                    yield CompletionResponseStreamChunk(
                        delta="Error processing response",
                    )

            return stream_generator()
        else:
            try:
                response = await asyncio.to_thread(
                    llm.complete,
                    prompt=interleaved_content_as_str(content),
                    config=config,
                    response_format=json_schema,
                )
                # If we have a JSON schema, ensure the response is valid JSON
                if json_schema is not None:
                    valid_json = self._handle_json_extraction(response.content, "completion response")
                    if valid_json:
                        return CompletionResponse(
                            content=valid_json,  # Already serialized in _handle_json_extraction
                            stop_reason=self._get_stop_reason(response.stats.stop_reason),
                        )
                    # If we couldn't extract valid JSON, continue with the original content
                return CompletionResponse(
                    content=response.content,
                    stop_reason=self._get_stop_reason(response.stats.stop_reason),
                )
            except Exception as e:
                self._log_error(e, "LMStudio completion")
                # Return a fallback response with an error message
                return self._create_fallback_completion_response()

    def _convert_message_list_to_lmstudio_chat(
        self, messages: List[Message]
    ) -> lms.Chat:
        chat = lms.Chat()
        for message in messages:
            if content_has_media(message.content):
                raise NotImplementedError(
                    "Media content is not supported in LMStudio messages"
                )
            if message.role == "user":
                chat.add_user_message(interleaved_content_as_str(message.content))
            elif message.role == "system":
                chat.add_system_prompt(interleaved_content_as_str(message.content))
            elif message.role == "assistant":
                chat.add_assistant_response(interleaved_content_as_str(message.content))
            else:
                raise ValueError(f"Unsupported message role: {message.role}")
        return chat

    def _convert_prediction_to_chat_response(
        self, result: lms.PredictionResult
    ) -> ChatCompletionResponse:
        response = ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=result.content,
                stop_reason=self._get_stop_reason(result.stats.stop_reason),
                tool_calls=None,
            )
        )
        return response

    def _get_completion_config_from_params(
        self,
        params: Optional[SamplingParams] = None,
    ) -> lms.LlmPredictionConfigDict:
        options = lms.LlmPredictionConfigDict()
        if params is None:
            return options
        if isinstance(params.strategy, GreedySamplingStrategy):
            options.update({"temperature": 0.0})
        elif isinstance(params.strategy, TopPSamplingStrategy):
            options.update(
                {
                    "temperature": params.strategy.temperature,
                    "topPSampling": params.strategy.top_p,
                }
            )
        elif isinstance(params.strategy, TopKSamplingStrategy):
            options.update({"topKSampling": params.strategy.top_k})
        else:
            raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
        options.update(
            {
                "maxTokens": params.max_tokens if params.max_tokens != 0 else None,
                "repetitionPenalty": (
                    params.repetition_penalty if params.repetition_penalty != 0 else None
                ),
            }
        )
        return options

    def _get_stop_reason(self, stop_reason: LlmPredictionStopReason) -> StopReason:
        if stop_reason == "eosFound":
            return StopReason.end_of_message
        elif stop_reason == "maxPredictedTokensReached":
            return StopReason.out_of_tokens
        else:
            return StopReason.end_of_turn

    async def _async_iterate(self, iterable):
        """Asynchronously iterate over a synchronous iterable."""
        # Convert the synchronous iterable to a list first to avoid StopIteration issues
        items = await asyncio.to_thread(list, iterable)
        for item in items:
            yield item

    async def _convert_request_to_rest_call(
        self, request: ChatCompletionRequest
    ) -> dict:
        compatible_request = self._convert_sampling_params(request.sampling_params)
        compatible_request["model"] = request.model
        compatible_request["messages"] = [
            await convert_message_to_openai_dict_new(m) for m in request.messages
        ]
        if request.response_format:
            compatible_request["response_format"] = {
                "type": "json_schema",
                "json_schema": request.response_format.json_schema,
            }
        if request.tools is not None:
            compatible_request["tools"] = [
                convert_tooldef_to_openai_tool(tool) for tool in request.tools
            ]
        compatible_request["logprobs"] = False
        compatible_request["stream"] = request.stream
        compatible_request["extra_headers"] = {
            b"User-Agent": b"llama-stack: lmstudio-inference-adapter"
        }
        return compatible_request

    def _convert_sampling_params(self, sampling_params: Optional[SamplingParams]) -> dict:
        params = {}
        if sampling_params is None:
            return params

        params["frequency_penalty"] = sampling_params.repetition_penalty
        if sampling_params.max_tokens:
            params["max_completion_tokens"] = sampling_params.max_tokens
        if isinstance(sampling_params.strategy, TopPSamplingStrategy):
            params["top_p"] = sampling_params.strategy.top_p
        elif isinstance(sampling_params.strategy, TopKSamplingStrategy):
            # top_k is not a standard OpenAI parameter, so pass it through extra_body
            params["extra_body"] = {"top_k": sampling_params.strategy.top_k}
        elif isinstance(sampling_params.strategy, GreedySamplingStrategy):
            params["temperature"] = 0.0
        return params
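

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not exercised by the adapter).
# It assumes an LM Studio server listening on "localhost:1234" and a model
# named "llama-3.2-1b-instruct" already downloaded; both values are
# placeholders for whatever is available locally.
#
#   async def _example() -> None:
#       client = LMStudioClient(url="localhost:1234")
#       if await client.check_if_model_present_in_lmstudio("llama-3.2-1b-instruct"):
#           llm = await client.get_llm("llama-3.2-1b-instruct")
#           response = await client.llm_completion(llm, "Say hello in one sentence.")
#           print(response.content)
#
#   asyncio.run(_example())
# ---------------------------------------------------------------------------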