Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 08:44:44 +00:00
Implement error handling; improve completion, tool calling, and streaming
This commit is contained in:
parent fe575a0fdf
commit 05777dfb52

1 changed file with 193 additions and 52 deletions
@@ -1,6 +1,10 @@
 import asyncio
 from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
 import lmstudio as lms
+import json
+import re
+import logging
 
+
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference.inference import (
@@ -53,6 +57,84 @@ class LMStudioClient:
         self.url = url
         self.sdk_client = lms.Client(self.url)
         self.openai_client = OpenAI(base_url=f"http://{url}/v1", api_key="lmstudio")
 
+    # Standard error handling helper methods
+    def _log_error(self, error, context=""):
+        """Centralized error logging method"""
+        logging.warning(f"Error in LMStudio {context}: {error}")
+
+    async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback stream for chat completions"""
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta=TextDelta(text=""),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta=TextDelta(text=error_message),
+            )
+        )
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta=TextDelta(text=""),
+            )
+        )
+
+    async def _create_fallback_completion_stream(self, error_message="Error processing response"):
+        """Create a standardized fallback stream for text completions"""
+        yield CompletionResponseStreamChunk(
+            delta=error_message,
+        )
+
+    def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
+        """Create a standardized fallback response for chat completions"""
+        return ChatCompletionResponse(
+            message=Message(
+                role="assistant",
+                content=error_message,
+            ),
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _create_fallback_completion_response(self, error_message="Error processing response"):
+        """Create a standardized fallback response for text completions"""
+        return CompletionResponse(
+            content=error_message,
+            stop_reason=StopReason.end_of_message,
+        )
+
+    def _handle_json_extraction(self, content, context="JSON extraction"):
+        """Standardized method to extract valid JSON from potentially malformed content"""
+        try:
+            json_content = json.loads(content)
+            return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+        except json.JSONDecodeError as e:
+            self._log_error(e, f"{context} - Attempting to extract valid JSON")
+
+        json_patterns = [
+            r'(\{.*\})',  # Match anything between curly braces
+            r'(\[.*\])',  # Match anything between square brackets
+            r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
+            r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
+        ]
+
+        for pattern in json_patterns:
+            json_match = re.search(pattern, content, re.DOTALL)
+            if json_match:
+                valid_json = json_match.group(1)
+                try:
+                    json_content = json.loads(valid_json)
+                    return json.dumps(json_content)  # Re-serialize to ensure valid JSON
+                except json.JSONDecodeError:
+                    continue  # Try the next pattern
+
+        # If we couldn't extract valid JSON, log a warning
+        self._log_error("Failed to extract valid JSON", context)
+        return None
+
     async def check_if_model_present_in_lmstudio(self, provider_model_id):
         models = await asyncio.to_thread(self.sdk_client.list_downloaded_models)
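As an aside, the extraction fallback added above can be reproduced in isolation. The snippet below is a minimal standalone sketch, not part of the diff; the names `extract_json` and `reply` are hypothetical, and it simply mirrors the regex-based recovery that `_handle_json_extraction` performs on malformed model output.

import json
import re

def extract_json(content: str):
    """Best-effort recovery of a JSON payload from raw model output."""
    try:
        return json.dumps(json.loads(content))  # already valid JSON
    except json.JSONDecodeError:
        pass
    patterns = [
        r'(\{.*\})',                      # anything between curly braces
        r'(\[.*\])',                      # anything between square brackets
        r'```json\s*([\s\S]*?)\s*```',    # JSON code blocks
        r'```\s*([\s\S]*?)\s*```',        # any code blocks
    ]
    for pattern in patterns:
        match = re.search(pattern, content, re.DOTALL)
        if match:
            try:
                return json.dumps(json.loads(match.group(1)))
            except json.JSONDecodeError:
                continue
    return None

reply = 'Sure! Here is the result:\n```json\n{"answer": 42}\n```'
print(extract_json(reply))  # {"answer": 42}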
@@ -94,6 +176,7 @@ class LMStudioClient:
         chat = self._convert_message_list_to_lmstudio_chat(messages)
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
                 prediction_stream = await asyncio.to_thread(
                     llm.respond_stream,
@@ -108,7 +191,19 @@ class LMStudioClient:
                         delta=TextDelta(text=""),
                     )
                 )
-                async for chunk in self._async_iterate(prediction_stream):
+
+                # Convert to list to avoid StopIteration issues
+                try:
+                    chunks = await asyncio.to_thread(list, prediction_stream)
+                except StopIteration:
+                    # Handle StopIteration by returning an empty list
+                    chunks = []
+                except Exception as e:
+                    self._log_error(e, "converting chat stream to list")
+                    chunks = []
+
+                # Yield each chunk
+                for chunk in chunks:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
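The change above swaps per-chunk iteration of the SDK's blocking prediction stream for a single `asyncio.to_thread(list, ...)` call, so `StopIteration` never leaks into the event loop. A minimal, self-contained sketch of that pattern follows; `blocking_stream` and `consume` are stand-in names, not the lmstudio SDK. Note the trade-off: the whole stream is buffered before any chunk is yielded.

import asyncio

def blocking_stream():
    # Stand-in for an SDK prediction stream that blocks between chunks.
    for text in ["Hel", "lo, ", "world"]:
        yield text

async def consume():
    # Drain the blocking iterator in a worker thread, then yield from the
    # buffered list; StopIteration stays inside list() rather than escaping.
    try:
        chunks = await asyncio.to_thread(list, blocking_stream())
    except Exception:
        chunks = []
    for chunk in chunks:
        print(chunk, end="")

asyncio.run(consume())  # prints "Hello, world"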
@@ -144,28 +239,48 @@ class LMStudioClient:
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
-        model_key = llm.get_info().model_key
-        request = ChatCompletionRequest(
-            model=model_key,
-            messages=messages,
-            sampling_params=sampling_params,
-            response_format=json_schema,
-            tools=tools,
-            tool_config=tool_config,
-            stream=stream,
-        )
-        rest_request = await self._convert_request_to_rest_call(request)
-        if stream:
-            stream = await self.openai_client.chat.completions.create(**rest_request)
-            return convert_openai_chat_completion_stream(
-                stream, enable_incremental_tool_calls=True
+        try:
+            model_key = llm.get_info().model_key
+            request = ChatCompletionRequest(
+                model=model_key,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=json_schema,
+                tools=tools,
+                tool_config=tool_config,
+                stream=stream,
             )
-        response = await self.openai_client.chat.completions.create(**rest_request)
-        if response:
-            result = convert_openai_chat_completion_choice(response.choices[0])
-            return result
-        else:
-            return None
+            rest_request = await self._convert_request_to_rest_call(request)
+
+            if stream:
+                try:
+                    stream = await self.openai_client.chat.completions.create(**rest_request)
+                    return convert_openai_chat_completion_stream(
+                        stream, enable_incremental_tool_calls=True
+                    )
+                except Exception as e:
+                    self._log_error(e, "streaming tool calling")
+                    return self._create_fallback_chat_stream()
+
+            try:
+                response = await self.openai_client.chat.completions.create(**rest_request)
+                if response:
+                    result = convert_openai_chat_completion_choice(response.choices[0])
+                    return result
+                else:
+                    # Handle empty response
+                    self._log_error("Empty response from OpenAI API", "chat completion")
+                    return self._create_fallback_chat_response()
+            except Exception as e:
+                self._log_error(e, "non-streaming tool calling")
+                return self._create_fallback_chat_response()
+        except Exception as e:
+            self._log_error(e, "_llm_respond_with_tools")
+            # Return a fallback response
+            if stream:
+                return self._create_fallback_chat_stream()
+            else:
+                return self._create_fallback_chat_response()
 
     async def llm_respond(
         self,
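The shape of the new error handling above can be summarized as: never let the exception escape, and pick the fallback by whether the caller asked for a stream, so the return type stays intact on the error path. The sketch below illustrates that contract with hypothetical stand-in types (`Reply`, `fallback_stream`, `respond`); it is not the llama-stack API.

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Union

@dataclass
class Reply:  # stand-in for ChatCompletionResponse
    content: str

async def fallback_stream(msg: str) -> AsyncIterator[str]:
    # Stand-in for _create_fallback_chat_stream: a well-formed, one-shot stream.
    yield msg

async def respond(stream: bool, fail: bool = True) -> Union[Reply, AsyncIterator[str]]:
    try:
        if fail:
            raise RuntimeError("backend unavailable")
        return Reply("ok")
    except RuntimeError:
        # Same contract on the error path: a stream when a stream was requested,
        # otherwise a complete response object.
        if stream:
            return fallback_stream("I encountered an error processing your request.")
        return Reply("I encountered an error processing your request.")

async def main():
    result = await respond(stream=True)
    async for chunk in result:
        print(chunk)

asyncio.run(main())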
@@ -208,30 +323,64 @@ class LMStudioClient:
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
+
             async def stream_generator():
-                prediction_stream = await asyncio.to_thread(
-                    llm.complete_stream,
-                    prompt=interleaved_content_as_str(content),
-                    config=config,
-                    response_format=json_schema,
-                )
-                async for chunk in self._async_iterate(prediction_stream):
+                try:
+                    prediction_stream = await asyncio.to_thread(
+                        llm.complete_stream,
+                        prompt=interleaved_content_as_str(content),
+                        config=config,
+                        response_format=json_schema,
+                    )
+
+                    try:
+                        chunks = await asyncio.to_thread(list, prediction_stream)
+                    except StopIteration:
+                        # Handle StopIteration by returning an empty list
+                        chunks = []
+                    except Exception as e:
+                        self._log_error(e, "converting completion stream to list")
+                        chunks = []
+
+                    for chunk in chunks:
+                        yield CompletionResponseStreamChunk(
+                            delta=chunk.content,
+                        )
+                except Exception as e:
+                    self._log_error(e, "streaming completion")
+                    # Return a fallback response in case of error
                     yield CompletionResponseStreamChunk(
-                        delta=chunk.content,
+                        delta="Error processing response",
                     )
 
             return stream_generator()
         else:
-            response = await asyncio.to_thread(
-                llm.complete,
-                prompt=interleaved_content_as_str(content),
-                config=config,
-                response_format=json_schema,
-            )
-            return CompletionResponse(
-                content=response.content,
-                stop_reason=self._get_stop_reason(response.stats.stop_reason),
-            )
+            try:
+                response = await asyncio.to_thread(
+                    llm.complete,
+                    prompt=interleaved_content_as_str(content),
+                    config=config,
+                    response_format=json_schema,
+                )
+
+                # If we have a JSON schema, ensure the response is valid JSON
+                if json_schema is not None:
+                    valid_json = self._handle_json_extraction(response.content, "completion response")
+                    if valid_json:
+                        return CompletionResponse(
+                            content=valid_json,  # Already serialized in _handle_json_extraction
+                            stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                        )
+                    # If we couldn't extract valid JSON, continue with the original content
+
+                return CompletionResponse(
+                    content=response.content,
+                    stop_reason=self._get_stop_reason(response.stats.stop_reason),
+                )
+            except Exception as e:
+                self._log_error(e, "LMStudio completion")
+                # Return a fallback response with an error message
+                return self._create_fallback_completion_response()
 
     def _convert_message_list_to_lmstudio_chat(
         self, messages: List[Message]
@@ -305,18 +454,10 @@ class LMStudioClient:
         return StopReason.end_of_turn
 
     async def _async_iterate(self, iterable):
-        iterator = iter(iterable)
-
-        def safe_next(it):
-            try:
-                return (next(it), False)
-            except StopIteration:
-                return (None, True)
-
-        while True:
-            item, done = await asyncio.to_thread(safe_next, iterator)
-            if done:
-                break
+        """Asynchronously iterate over a synchronous iterable."""
+        # Convert the synchronous iterable to a list first to avoid StopIteration issues
+        items = await asyncio.to_thread(list, iterable)
+        for item in items:
             yield item
 
     async def _convert_request_to_rest_call(
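Worth noting: the rewritten `_async_iterate` above materializes the whole iterable before yielding, so downstream consumers no longer receive items as they are produced. If incremental delivery mattered, the deleted `safe_next` approach (one `next()` call per worker-thread hop, with `StopIteration` converted to a sentinel) could be kept instead. A compact standalone sketch of that older pattern, with the hypothetical name `iterate_incrementally`:

import asyncio

async def iterate_incrementally(iterable):
    # One next() per worker-thread hop; StopIteration is converted to a
    # (None, True) sentinel so it never propagates out of the thread.
    iterator = iter(iterable)

    def safe_next(it):
        try:
            return next(it), False
        except StopIteration:
            return None, True

    while True:
        item, done = await asyncio.to_thread(safe_next, iterator)
        if done:
            break
        yield item

async def main():
    async for chunk in iterate_incrementally(["a", "b", "c"]):
        print(chunk)

asyncio.run(main())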