implement error handling; improve completion, tool calling, and streaming

Justin Lee 2025-03-21 16:52:32 -07:00 committed by Matt Clayton
parent fe575a0fdf
commit 05777dfb52


@@ -1,6 +1,10 @@
 import asyncio
 from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
 import lmstudio as lms
+import json
+import re
+import logging
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference.inference import (
@@ -53,6 +57,84 @@ class LMStudioClient:
         self.url = url
         self.sdk_client = lms.Client(self.url)
         self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
+
+    # Standard error handling helper methods
+    def _log_error(self, error, context=""):
+        """Centralized error logging method"""
+        logging.warning(f"Error in LMStudio {context}: {error}")
+
+    async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback stream for chat completions"""
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta=TextDelta(text=""),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta=TextDelta(text=error_message),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta=TextDelta(text=""),
+            )
+        )
+
+    async def _create_fallback_completion_stream(self, error_message="Error processing response"):
+        """Create a standardized fallback stream for text completions"""
+        yield CompletionResponseStreamChunk(
+            delta=error_message,
+        )
+
+    def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback response for chat completions"""
+        return ChatCompletionResponse(
+            message=Message(
+                role="assistant",
+                content=error_message,
+            ),
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _create_fallback_completion_response(self, error_message="Error processing response"):
+        """Create a standardized fallback response for text completions"""
+        return CompletionResponse(
+            content=error_message,
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _handle_json_extraction(self, content, context="JSON extraction"):
+        """Standardized method to extract valid JSON from potentially malformed content"""
+        try:
+            json_content = json.loads(content)
+            return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+        except json.JSONDecodeError as e:
+            self._log_error(e, f"{context} - Attempting to extract valid JSON")
+
+            json_patterns = [
+                r'(\{.*\})',  # Match anything between curly braces
+                r'(\[.*\])',  # Match anything between square brackets
+                r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
+                r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
+            ]
+            for pattern in json_patterns:
+                json_match = re.search(pattern, content, re.DOTALL)
+                if json_match:
+                    valid_json = json_match.group(1)
+                    try:
+                        json_content = json.loads(valid_json)
+                        return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+                    except json.JSONDecodeError:
+                        continue  # Try the next pattern
+
+            # If we couldn't extract valid JSON, log a warning
+            self._log_error("Failed to extract valid JSON", context)
+            return None
 
     async def check_if_model_present_in_lmstudio(self, provider_model_id):
         models = await asyncio.to_thread(self.sdk_client.list_downloaded_models)
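
Note on the JSON-recovery helper above: it first attempts a strict json.loads and only falls back to regex extraction when that fails, re-serializing whatever it finds so callers always receive canonical JSON or None. A minimal standalone sketch of the same idea (the extract_json name and the sample reply are illustrative, not part of the commit):

import json
import re

def extract_json(content: str):
    """Illustrative sketch: return re-serialized JSON found in `content`, or None."""
    try:
        return json.dumps(json.loads(content))  # content is already valid JSON
    except json.JSONDecodeError:
        pass
    # Fall back to pulling a JSON object/array out of braces, brackets, or code fences
    for pattern in (r'(\{.*\})', r'(\[.*\])', r'```json\s*([\s\S]*?)\s*```'):
        match = re.search(pattern, content, re.DOTALL)
        if match:
            try:
                return json.dumps(json.loads(match.group(1)))
            except json.JSONDecodeError:
                continue
    return None

# A model reply that buries its JSON in prose still round-trips cleanly:
print(extract_json('Sure! {"temperature": 0.7, "stream": false}'))
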
@@ -94,6 +176,7 @@ class LMStudioClient:
         chat = self._convert_message_list_to_lmstudio_chat(messages)
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
                 prediction_stream = await asyncio.to_thread(
                     llm.respond_stream,
@@ -108,7 +191,19 @@ class LMStudioClient:
                         delta=TextDelta(text=""),
                     )
                 )
-                async for chunk in self._async_iterate(prediction_stream):
+
+                # Convert to list to avoid StopIteration issues
+                try:
+                    chunks = await asyncio.to_thread(list, prediction_stream)
+                except StopIteration:
+                    # Handle StopIteration by returning an empty list
+                    chunks = []
+                except Exception as e:
+                    self._log_error(e, "converting chat stream to list")
+                    chunks = []
+
+                # Yield each chunk
+                for chunk in chunks:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -144,28 +239,48 @@
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
-        model_key = llm.get_info().model_key
-        request = ChatCompletionRequest(
-            model=model_key,
-            messages=messages,
-            sampling_params=sampling_params,
-            response_format=json_schema,
-            tools=tools,
-            tool_config=tool_config,
-            stream=stream,
-        )
-        rest_request = await self._convert_request_to_rest_call(request)
-        if stream:
-            stream = await self.openai_client.chat.completions.create(**rest_request)
-            return convert_openai_chat_completion_stream(
-                stream, enable_incremental_tool_calls=True
-            )
-        response = await self.openai_client.chat.completions.create(**rest_request)
-        if response:
-            result = convert_openai_chat_completion_choice(response.choices[0])
-            return result
-        else:
-            return None
+        try:
+            model_key = llm.get_info().model_key
+            request = ChatCompletionRequest(
+                model=model_key,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=json_schema,
+                tools=tools,
+                tool_config=tool_config,
+                stream=stream,
+            )
+            rest_request = await self._convert_request_to_rest_call(request)
+
+            if stream:
+                try:
+                    stream = await self.openai_client.chat.completions.create(**rest_request)
+                    return convert_openai_chat_completion_stream(
+                        stream, enable_incremental_tool_calls=True
+                    )
+                except Exception as e:
+                    self._log_error(e, "streaming tool calling")
+                    return self._create_fallback_chat_stream()
+
+            try:
+                response = await self.openai_client.chat.completions.create(**rest_request)
+                if response:
+                    result = convert_openai_chat_completion_choice(response.choices[0])
+                    return result
+                else:
+                    # Handle empty response
+                    self._log_error("Empty response from OpenAI API", "chat completion")
+                    return self._create_fallback_chat_response()
+            except Exception as e:
+                self._log_error(e, "non-streaming tool calling")
+                return self._create_fallback_chat_response()
+        except Exception as e:
+            self._log_error(e, "_llm_respond_with_tools")
+            # Return a fallback response
+            if stream:
+                return self._create_fallback_chat_stream()
+            else:
+                return self._create_fallback_chat_response()
 
     async def llm_respond(
         self,
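
Both failure paths in the tool-calling method above keep the declared return contract: a streaming caller still receives an async iterator and a non-streaming caller still receives a response object, so downstream code never has to branch on whether an error occurred. A generic, hypothetical illustration of that pattern, with plain dicts standing in for the llama_stack types:

import asyncio

async def fallback_stream(message="I encountered an error processing your request."):
    # Stand-in for _create_fallback_chat_stream: start / progress / complete chunks
    yield {"event": "start", "text": ""}
    yield {"event": "progress", "text": message}
    yield {"event": "complete", "text": ""}

async def respond(stream: bool):
    try:
        raise RuntimeError("upstream call failed")  # simulate the OpenAI-compatible call failing
    except RuntimeError:
        return fallback_stream() if stream else {"role": "assistant", "content": "error"}

async def main():
    result = await respond(stream=True)
    async for chunk in result:  # the caller consumes the same interface it asked for
        print(chunk["event"], chunk["text"])

asyncio.run(main())
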
@@ -208,30 +323,64 @@ class LMStudioClient:
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
-                prediction_stream = await asyncio.to_thread(
-                    llm.complete_stream,
-                    prompt=interleaved_content_as_str(content),
-                    config=config,
-                    response_format=json_schema,
-                )
-                async for chunk in self._async_iterate(prediction_stream):
+                try:
+                    prediction_stream = await asyncio.to_thread(
+                        llm.complete_stream,
+                        prompt=interleaved_content_as_str(content),
+                        config=config,
+                        response_format=json_schema,
+                    )
+
+                    try:
+                        chunks = await asyncio.to_thread(list, prediction_stream)
+                    except StopIteration:
+                        # Handle StopIteration by returning an empty list
+                        chunks = []
+                    except Exception as e:
+                        self._log_error(e, "converting completion stream to list")
+                        chunks = []
+
+                    for chunk in chunks:
+                        yield CompletionResponseStreamChunk(
+                            delta=chunk.content,
+                        )
+                except Exception as e:
+                    self._log_error(e, "streaming completion")
+                    # Return a fallback response in case of error
                     yield CompletionResponseStreamChunk(
-                        delta=chunk.content,
+                        delta="Error processing response",
                     )
 
             return stream_generator()
         else:
-            response = await asyncio.to_thread(
-                llm.complete,
-                prompt=interleaved_content_as_str(content),
-                config=config,
-                response_format=json_schema,
-            )
-            return CompletionResponse(
-                content=response.content,
-                stop_reason=self._get_stop_reason(response.stats.stop_reason),
-            )
+            try:
+                response = await asyncio.to_thread(
+                    llm.complete,
+                    prompt=interleaved_content_as_str(content),
+                    config=config,
+                    response_format=json_schema,
+                )
+
+                # If we have a JSON schema, ensure the response is valid JSON
+                if json_schema is not None:
+                    valid_json = self._handle_json_extraction(response.content, "completion response")
+                    if valid_json:
+                        return CompletionResponse(
+                            content=valid_json,  # Already serialized in _handle_json_extraction
+                            stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                        )
+
+                # If we couldn't extract valid JSON, continue with the original content
+                return CompletionResponse(
+                    content=response.content,
+                    stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                )
+            except Exception as e:
+                self._log_error(e, "LMStudio completion")
+                # Return a fallback response with an error message
+                return self._create_fallback_completion_response()
 
     def _convert_message_list_to_lmstudio_chat(
         self, messages: List[Message]
@@ -305,18 +454,10 @@ class LMStudioClient:
         return StopReason.end_of_turn
 
     async def _async_iterate(self, iterable):
-        iterator = iter(iterable)
-
-        def safe_next(it):
-            try:
-                return (next(it), False)
-            except StopIteration:
-                return (None, True)
-
-        while True:
-            item, done = await asyncio.to_thread(safe_next, iterator)
-            if done:
-                break
+        """Asynchronously iterate over a synchronous iterable."""
+        # Convert the synchronous iterable to a list first to avoid StopIteration issues
+        items = await asyncio.to_thread(list, iterable)
+        for item in items:
             yield item
 
     async def _convert_request_to_rest_call(
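
The reworked _async_iterate above trades incremental delivery for robustness: the entire synchronous prediction stream is drained inside a single asyncio.to_thread call, so the per-item next() calls of the old safe_next approach, and the StopIteration they raise, never have to cross the thread/coroutine boundary. A self-contained sketch of the pattern, with a hypothetical synchronous generator standing in for an lmstudio stream:

import asyncio

def sync_chunks():
    # Stand-in for an lmstudio prediction stream: an ordinary synchronous generator
    yield from ("alpha", "beta", "gamma")

async def async_iterate(iterable):
    # Materialize the synchronous iterable in a worker thread, then yield from the event loop
    items = await asyncio.to_thread(list, iterable)
    for item in items:
        yield item

async def main():
    async for chunk in async_iterate(sync_chunks()):
        print(chunk)

asyncio.run(main())
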