From 05777dfb52471f7b3a05b7706c4bf9e710dcfdec Mon Sep 17 00:00:00 2001
From: Justin Lee
Date: Fri, 21 Mar 2025 16:52:32 -0700
Subject: [PATCH] implement error handling, improve completion, tool calling
 and streaming

---
 .../remote/inference/lmstudio/_client.py | 245 ++++++++++++++----
 1 file changed, 193 insertions(+), 52 deletions(-)

diff --git a/llama_stack/providers/remote/inference/lmstudio/_client.py b/llama_stack/providers/remote/inference/lmstudio/_client.py
index 5359585b0..f03cb7bc0 100644
--- a/llama_stack/providers/remote/inference/lmstudio/_client.py
+++ b/llama_stack/providers/remote/inference/lmstudio/_client.py
@@ -1,6 +1,10 @@
 import asyncio
 from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
 import lmstudio as lms
+import json
+import re
+import logging
+
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference.inference import (
@@ -53,6 +57,84 @@ class LMStudioClient:
         self.url = url
         self.sdk_client = lms.Client(self.url)
         self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
+
+    # Standard error handling helper methods
+    def _log_error(self, error, context=""):
+        """Centralized error logging method"""
+        logging.warning(f"Error in LMStudio {context}: {error}")
+
+    async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback stream for chat completions"""
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta=TextDelta(text=""),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta=TextDelta(text=error_message),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta=TextDelta(text=""),
+            )
+        )
+
+    async def _create_fallback_completion_stream(self, error_message="Error processing response"):
+        """Create a standardized fallback stream for text completions"""
+        yield CompletionResponseStreamChunk(
+            delta=error_message,
+        )
+
+    def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback response for chat completions"""
+        return ChatCompletionResponse(
+            message=Message(
+                role="assistant",
+                content=error_message,
+            ),
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _create_fallback_completion_response(self, error_message="Error processing response"):
+        """Create a standardized fallback response for text completions"""
+        return CompletionResponse(
+            content=error_message,
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _handle_json_extraction(self, content, context="JSON extraction"):
+        """Standardized method to extract valid JSON from potentially malformed content"""
+        try:
+            json_content = json.loads(content)
+            return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+        except json.JSONDecodeError as e:
+            self._log_error(e, f"{context} - Attempting to extract valid JSON")
+
+        json_patterns = [
+            r'(\{.*\})',  # Match anything between curly braces
+            r'(\[.*\])',  # Match anything between square brackets
+            r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
+            r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
+        ]
+
+        for pattern in json_patterns:
+            json_match = re.search(pattern, content, re.DOTALL)
+            if json_match:
+                valid_json = json_match.group(1)
+                try:
+                    json_content = json.loads(valid_json)
+                    return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+                except json.JSONDecodeError:
+                    continue  # Try the next pattern
+
+        # If we couldn't extract valid JSON, log a warning
+        self._log_error("Failed to extract valid JSON", context)
+        return None
 
     async def check_if_model_present_in_lmstudio(self, provider_model_id):
         models = await asyncio.to_thread(self.sdk_client.list_downloaded_models)
@@ -94,6 +176,7 @@ class LMStudioClient:
         chat = self._convert_message_list_to_lmstudio_chat(messages)
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
                 prediction_stream = await asyncio.to_thread(
                     llm.respond_stream,
@@ -108,7 +191,19 @@
                         delta=TextDelta(text=""),
                     )
                 )
-                async for chunk in self._async_iterate(prediction_stream):
+
+                # Convert to list to avoid StopIteration issues
+                try:
+                    chunks = await asyncio.to_thread(list, prediction_stream)
+                except StopIteration:
+                    # Handle StopIteration by returning an empty list
+                    chunks = []
+                except Exception as e:
+                    self._log_error(e, "converting chat stream to list")
+                    chunks = []
+
+                # Yield each chunk
+                for chunk in chunks:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -144,28 +239,48 @@
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
-        model_key = llm.get_info().model_key
-        request = ChatCompletionRequest(
-            model=model_key,
-            messages=messages,
-            sampling_params=sampling_params,
-            response_format=json_schema,
-            tools=tools,
-            tool_config=tool_config,
-            stream=stream,
-        )
-        rest_request = await self._convert_request_to_rest_call(request)
-        if stream:
-            stream = await self.openai_client.chat.completions.create(**rest_request)
-            return convert_openai_chat_completion_stream(
-                stream, enable_incremental_tool_calls=True
+        try:
+            model_key = llm.get_info().model_key
+            request = ChatCompletionRequest(
+                model=model_key,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=json_schema,
+                tools=tools,
+                tool_config=tool_config,
+                stream=stream,
             )
-        response = await self.openai_client.chat.completions.create(**rest_request)
-        if response:
-            result = convert_openai_chat_completion_choice(response.choices[0])
-            return result
-        else:
-            return None
+            rest_request = await self._convert_request_to_rest_call(request)
+
+            if stream:
+                try:
+                    stream = await self.openai_client.chat.completions.create(**rest_request)
+                    return convert_openai_chat_completion_stream(
+                        stream, enable_incremental_tool_calls=True
+                    )
+                except Exception as e:
+                    self._log_error(e, "streaming tool calling")
+                    return self._create_fallback_chat_stream()
+
+            try:
+                response = await self.openai_client.chat.completions.create(**rest_request)
+                if response:
+                    result = convert_openai_chat_completion_choice(response.choices[0])
+                    return result
+                else:
+                    # Handle empty response
+                    self._log_error("Empty response from OpenAI API", "chat completion")
+                    return self._create_fallback_chat_response()
+            except Exception as e:
+                self._log_error(e, "non-streaming tool calling")
+                return self._create_fallback_chat_response()
+        except Exception as e:
+            self._log_error(e, "_llm_respond_with_tools")
+            # Return a fallback response
+            if stream:
+                return self._create_fallback_chat_stream()
+            else:
+                return self._create_fallback_chat_response()
 
     async def llm_respond(
         self,
@@ -208,30 +323,64 @@
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
-                prediction_stream = await asyncio.to_thread(
-                    llm.complete_stream,
-                    prompt=interleaved_content_as_str(content),
-                    config=config,
-                    response_format=json_schema,
-                )
-                async for chunk in self._async_iterate(prediction_stream):
+                try:
+                    prediction_stream = await asyncio.to_thread(
+                        llm.complete_stream,
+                        prompt=interleaved_content_as_str(content),
+                        config=config,
+                        response_format=json_schema,
+                    )
+
+                    try:
+                        chunks = await asyncio.to_thread(list, prediction_stream)
+                    except StopIteration:
+                        # Handle StopIteration by returning an empty list
+                        chunks = []
+                    except Exception as e:
+                        self._log_error(e, "converting completion stream to list")
+                        chunks = []
+
+                    for chunk in chunks:
+                        yield CompletionResponseStreamChunk(
+                            delta=chunk.content,
+                        )
+                except Exception as e:
+                    self._log_error(e, "streaming completion")
+                    # Return a fallback response in case of error
                     yield CompletionResponseStreamChunk(
-                        delta=chunk.content,
+                        delta="Error processing response",
                     )
 
             return stream_generator()
         else:
-            response = await asyncio.to_thread(
-                llm.complete,
-                prompt=interleaved_content_as_str(content),
-                config=config,
-                response_format=json_schema,
-            )
-            return CompletionResponse(
-                content=response.content,
-                stop_reason=self._get_stop_reason(response.stats.stop_reason),
-            )
+            try:
+                response = await asyncio.to_thread(
+                    llm.complete,
+                    prompt=interleaved_content_as_str(content),
+                    config=config,
+                    response_format=json_schema,
+                )
+
+                # If we have a JSON schema, ensure the response is valid JSON
+                if json_schema is not None:
+                    valid_json = self._handle_json_extraction(response.content, "completion response")
+                    if valid_json:
+                        return CompletionResponse(
+                            content=valid_json,  # Already serialized in _handle_json_extraction
+                            stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                        )
+                    # If we couldn't extract valid JSON, continue with the original content
+
+                return CompletionResponse(
+                    content=response.content,
+                    stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                )
+            except Exception as e:
+                self._log_error(e, "LMStudio completion")
+                # Return a fallback response with an error message
+                return self._create_fallback_completion_response()
 
     def _convert_message_list_to_lmstudio_chat(
         self, messages: List[Message]
@@ -305,18 +454,10 @@
         return StopReason.end_of_turn
 
     async def _async_iterate(self, iterable):
-        iterator = iter(iterable)
-
-        def safe_next(it):
-            try:
-                return (next(it), False)
-            except StopIteration:
-                return (None, True)
-
-        while True:
-            item, done = await asyncio.to_thread(safe_next, iterator)
-            if done:
-                break
+        """Asynchronously iterate over a synchronous iterable."""
+        # Convert the synchronous iterable to a list first to avoid StopIteration issues
+        items = await asyncio.to_thread(list, iterable)
+        for item in items:
             yield item
 
     async def _convert_request_to_rest_call(
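Note: the streaming changes above all rely on the same pattern: drain the synchronous LM Studio prediction stream with list() inside asyncio.to_thread, so the blocking iteration (and its terminating StopIteration) stays off the event loop, then re-yield the buffered chunks. Below is a minimal standalone sketch of that pattern for reference; the names fake_prediction_stream and drain_in_thread are illustrative only and are not part of this patch.

import asyncio


def fake_prediction_stream():
    # Stand-in for a blocking SDK stream such as llm.respond_stream(...).
    yield "Hello"
    yield ", world"


async def drain_in_thread(sync_stream):
    # list() runs in a worker thread, so blocking iteration and the
    # generator's StopIteration never surface inside the event loop.
    chunks = await asyncio.to_thread(list, sync_stream)
    for chunk in chunks:
        yield chunk


async def main():
    async for chunk in drain_in_thread(fake_prediction_stream()):
        print(chunk)


asyncio.run(main())

The trade-off, as in _async_iterate above, is that the whole stream is buffered before any chunk is re-yielded, so consumers only start receiving output once the underlying SDK call has finished.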