Pass 1 for pre-commit fixes

Matt Clayton 2025-04-27 15:24:37 -04:00
parent cfc6bdae68
commit 59e1c5f4a0
7 changed files with 119 additions and 109 deletions

View file

@@ -33,7 +33,7 @@ The following models are available by default:
 - `meta-llama-3.1-70b-instruct `
 - `llama-3.2-1b-instruct `
 - `llama-3.2-3b-instruct `
-- `llama-3.2-70b-instruct `
+- `llama-3.3-70b-instruct `
 - `nomic-embed-text-v1.5 `
 - `all-minilm-l6-v2 `

View file

@@ -1,13 +1,20 @@
-import asyncio
-from typing import AsyncIterator, AsyncGenerator, List, Literal, Optional, Union
-import lmstudio as lms
-import json
-import re
-import logging
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+import logging
+import re
+from typing import Any, AsyncIterator, List, Literal, Optional, Union
+
+import lmstudio as lms
+from openai import AsyncOpenAI as OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
-from llama_stack.apis.inference.inference import (
+from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -16,15 +23,14 @@ from llama_stack.apis.inference.inference import (
     CompletionMessage,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GrammarResponseFormat,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     Message,
-    ToolConfig,
-    ToolDefinition,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     SamplingParams,
     StopReason,
+    ToolConfig,
+    ToolDefinition,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
@@ -38,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     content_has_media,
     interleaved_content_as_str,
 )
-from openai import AsyncOpenAI as OpenAI
 
 LlmPredictionStopReason = Literal[
     "userStopped",
@@ -63,7 +68,9 @@ class LMStudioClient:
         """Centralized error logging method"""
         logging.warning(f"Error in LMStudio {context}: {error}")
 
-    async def _create_fallback_chat_stream(self, error_message="I encountered an error processing your request."):
+    async def _create_fallback_chat_stream(
+        self, error_message="I encountered an error processing your request."
+    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Create a standardized fallback stream for chat completions"""
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
@@ -90,17 +97,19 @@ class LMStudioClient:
                 delta=error_message,
             )
 
-    def _create_fallback_chat_response(self, error_message="I encountered an error processing your request."):
+    def _create_fallback_chat_response(
+        self, error_message="I encountered an error processing your request."
+    ) -> ChatCompletionResponse:
         """Create a standardized fallback response for chat completions"""
         return ChatCompletionResponse(
-            message=Message(
+            completion_message=CompletionMessage(
                 role="assistant",
                 content=error_message,
-            ),
-            stop_reason=StopReason.end_of_message,
+                stop_reason=StopReason.end_of_message,
+            )
         )
 
-    def _create_fallback_completion_response(self, error_message="Error processing response"):
+    def _create_fallback_completion_response(self, error_message="Error processing response") -> CompletionResponse:
         """Create a standardized fallback response for text completions"""
         return CompletionResponse(
             content=error_message,
@@ -116,10 +125,10 @@ class LMStudioClient:
             self._log_error(e, f"{context} - Attempting to extract valid JSON")
 
             json_patterns = [
-                r'(\{.*\})',  # Match anything between curly braces
-                r'(\[.*\])',  # Match anything between square brackets
-                r'```json\s*([\s\S]*?)\s*```',  # Match content in JSON code blocks
-                r'```\s*([\s\S]*?)\s*```',  # Match content in any code blocks
+                r"(\{.*\})",  # Match anything between curly braces
+                r"(\[.*\])",  # Match anything between square brackets
+                r"```json\s*([\s\S]*?)\s*```",  # Match content in JSON code blocks
+                r"```\s*([\s\S]*?)\s*```",  # Match content in any code blocks
             ]
 
             for pattern in json_patterns:
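For context, a minimal standalone sketch (not part of this commit) of how patterns like the ones above can recover JSON from a model reply; the `extract_json` helper name is hypothetical:

```python
import json
import re

# Same patterns as in the diff above, tried in order.
json_patterns = [
    r"(\{.*\})",  # anything between curly braces
    r"(\[.*\])",  # anything between square brackets
    r"```json\s*([\s\S]*?)\s*```",  # content in JSON code blocks
    r"```\s*([\s\S]*?)\s*```",  # content in any code blocks
]


def extract_json(text: str):
    """Return the first regex capture that parses as JSON, else None."""
    for pattern in json_patterns:
        match = re.search(pattern, text)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                continue
    return None


print(extract_json('Sure, here it is:\n```json\n{"answer": 42}\n```'))  # {'answer': 42}
```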
@@ -148,14 +157,10 @@ class LMStudioClient:
         return False
 
     async def get_embedding_model(self, provider_model_id: str):
-        model = await asyncio.to_thread(
-            self.sdk_client.embedding.model, provider_model_id
-        )
+        model = await asyncio.to_thread(self.sdk_client.embedding.model, provider_model_id)
         return model
 
-    async def embed(
-        self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]
-    ):
+    async def embed(self, embedding_model: lms.EmbeddingModel, contents: Union[str, List[str]]):
         embeddings = await asyncio.to_thread(embedding_model.embed, contents)
         return embeddings
@@ -170,14 +175,12 @@ class LMStudioClient:
         sampling_params: Optional[SamplingParams] = None,
         json_schema: Optional[JsonSchemaResponseFormat] = None,
         stream: Optional[bool] = False,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         chat = self._convert_message_list_to_lmstudio_chat(messages)
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
 
-            async def stream_generator():
+            async def stream_generator() -> AsyncIterator[ChatCompletionResponseStreamChunk]:
                 prediction_stream = await asyncio.to_thread(
                     llm.respond_stream,
                     history=chat,
@@ -225,9 +228,7 @@ class LMStudioClient:
         stream: Optional[bool] = False,
         tools: Optional[List[ToolDefinition]] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         try:
             model_key = llm.get_info().model_key
             request = ChatCompletionRequest(
@@ -244,9 +245,7 @@ class LMStudioClient:
             if stream:
                 try:
                     stream = await self.openai_client.chat.completions.create(**rest_request)
-                    return convert_openai_chat_completion_stream(
-                        stream, enable_incremental_tool_calls=True
-                    )
+                    return convert_openai_chat_completion_stream(stream, enable_incremental_tool_calls=True)
                 except Exception as e:
                     self._log_error(e, "streaming tool calling")
                     return self._create_fallback_chat_stream()
@@ -280,9 +279,7 @@ class LMStudioClient:
         stream: Optional[bool] = False,
         tools: Optional[List[ToolDefinition]] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if tools is None or len(tools) == 0:
             return await self._llm_respond_non_tools(
                 llm=llm,
@@ -313,7 +310,7 @@ class LMStudioClient:
         config = self._get_completion_config_from_params(sampling_params)
         if stream:
 
-            async def stream_generator():
+            async def stream_generator() -> AsyncIterator[CompletionResponseStreamChunk]:
                 try:
                     prediction_stream = await asyncio.to_thread(
                         llm.complete_stream,
@@ -361,15 +358,11 @@ class LMStudioClient:
             # Return a fallback response with an error message
             return self._create_fallback_completion_response()
 
-    def _convert_message_list_to_lmstudio_chat(
-        self, messages: List[Message]
-    ) -> lms.Chat:
+    def _convert_message_list_to_lmstudio_chat(self, messages: List[Message]) -> lms.Chat:
         chat = lms.Chat()
         for message in messages:
             if content_has_media(message.content):
-                raise NotImplementedError(
-                    "Media content is not supported in LMStudio messages"
-                )
+                raise NotImplementedError("Media content is not supported in LMStudio messages")
             if message.role == "user":
                 chat.add_user_message(interleaved_content_as_str(message.content))
             elif message.role == "system":
@@ -380,9 +373,7 @@ class LMStudioClient:
                 raise ValueError(f"Unsupported message role: {message.role}")
         return chat
 
-    def _convert_prediction_to_chat_response(
-        self, result: lms.PredictionResult
-    ) -> ChatCompletionResponse:
+    def _convert_prediction_to_chat_response(self, result: lms.PredictionResult) -> ChatCompletionResponse:
         response = ChatCompletionResponse(
             completion_message=CompletionMessage(
                 content=result.content,
@@ -415,11 +406,7 @@ class LMStudioClient:
             options.update(
                 {
                     "maxTokens": params.max_tokens if params.max_tokens != 0 else None,
-                    "repetitionPenalty": (
-                        params.repetition_penalty
-                        if params.repetition_penalty != 0
-                        else None
-                    ),
+                    "repetitionPenalty": (params.repetition_penalty if params.repetition_penalty != 0 else None),
                 }
             )
         return options
@@ -449,32 +436,30 @@ class LMStudioClient:
                 break
             yield item
 
-    async def _convert_request_to_rest_call(
-        self, request: ChatCompletionRequest
-    ) -> dict:
+    async def _convert_request_to_rest_call(self, request: ChatCompletionRequest) -> dict:
         compatible_request = self._convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = [
-            await convert_message_to_openai_dict_new(m) for m in request.messages
-        ]
+        compatible_request["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
         if request.response_format:
-            compatible_request["response_format"] = {
-                "type": "json_schema",
-                "json_schema": request.response_format.json_schema,
-            }
+            if isinstance(request.response_format, JsonSchemaResponseFormat):
+                compatible_request["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": request.response_format.json_schema,
+                }
+            elif isinstance(request.response_format, GrammarResponseFormat):
+                compatible_request["response_format"] = {
+                    "type": "grammar",
+                    "bnf": request.response_format.bnf,
+                }
         if request.tools is not None:
-            compatible_request["tools"] = [
-                convert_tooldef_to_openai_tool(tool) for tool in request.tools
-            ]
+            compatible_request["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
         compatible_request["logprobs"] = False
         compatible_request["stream"] = request.stream
-        compatible_request["extra_headers"] = {
-            b"User-Agent": b"llama-stack: lmstudio-inference-adapter"
-        }
+        compatible_request["extra_headers"] = {b"User-Agent": b"llama-stack: lmstudio-inference-adapter"}
         return compatible_request
 
     def _convert_sampling_params(self, sampling_params: Optional[SamplingParams]) -> dict:
-        params = {}
+        params: dict[str, Any] = {}
         if sampling_params is None:
             return params
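As an aside (not part of this commit), the isinstance-based response-format dispatch in the hunk above can be pictured with plain stand-in classes; the names below are hypothetical, not the real llama_stack types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class JsonSchemaFormatStub:
    json_schema: dict


@dataclass
class GrammarFormatStub:
    bnf: str


def to_openai_response_format(fmt) -> Optional[dict]:
    # Mirrors the dispatch above: JSON-schema formats map to "json_schema",
    # grammar formats map to "grammar", anything else is left out of the request.
    if isinstance(fmt, JsonSchemaFormatStub):
        return {"type": "json_schema", "json_schema": fmt.json_schema}
    if isinstance(fmt, GrammarFormatStub):
        return {"type": "grammar", "bnf": fmt.bnf}
    return None


print(to_openai_response_format(JsonSchemaFormatStub({"type": "object"})))
# {'type': 'json_schema', 'json_schema': {'type': 'object'}}
```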

View file

@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -29,7 +31,6 @@ from llama_stack.apis.inference.inference import (
     ChatCompletionResponseStreamChunk,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    ResponseFormatType,
 )
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
@@ -50,6 +51,18 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
     def client(self) -> LMStudioClient:
         return LMStudioClient(url=self.url)
 
+    async def batch_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
+
+    async def batch_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch completion not supported by LM Studio Provider")
+
+    async def openai_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
+
+    async def openai_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
+
     async def initialize(self) -> None:
         pass
@@ -71,9 +84,12 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert all(not content_has_media(content) for content in contents), (
             "Media content not supported in embedding model"
         )
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         embedding_model = await self.client.get_embedding_model(model.provider_model_id)
-        embeddings = await self.client.embed(embedding_model, contents)
+        string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
+        embeddings = await self.client.embed(embedding_model, string_contents)
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def chat_completion(
@@ -81,26 +97,31 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_choice: Optional[ToolChoice] = None,  # Default value changed from ToolChoice.auto to None
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[
+            Union[JsonSchemaResponseFormat, GrammarResponseFormat]
+        ] = None,  # Moved and type changed
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         llm = await self.client.get_llm(model.provider_model_id)
-        if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
-            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
-        json_schema = response_format.json_schema if response_format else None
+        json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
+        if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
+            )
         return await self.client.llm_respond(
             llm=llm,
             messages=messages,
             sampling_params=sampling_params,
-            json_schema=json_schema,
+            json_schema=json_schema_format,
             stream=stream,
             tool_config=tool_config,
             tools=tools,
@@ -115,13 +136,16 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,  # Skip this for now
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         llm = await self.client.get_llm(model.provider_model_id)
         if content_has_media(content):
             raise NotImplementedError("Media content not supported in LM Studio Provider")
-        if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
-            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
-        json_schema = response_format.json_schema if response_format else None
-        return await self.client.llm_completion(llm, content, sampling_params, json_schema, stream)
+        if not isinstance(response_format, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
+            )
+        return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
 )
@@ -63,9 +63,7 @@ MODEL_ENTRIES = [
         },
     ),
     ProviderModelEntry(
-        model_id="all-MiniLM-L6-v2",
         provider_model_id="all-minilm-l6-v2",
-        provider_id="lmstudio",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,

View file

@@ -351,10 +351,12 @@
     "chardet",
     "chromadb-client",
     "datasets",
+    "emoji",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
+    "langdetect",
     "lmstudio",
     "matplotlib",
     "nltk",
@@ -367,6 +369,7 @@
     "psycopg2-binary",
     "pymongo",
     "pypdf",
+    "pythainlp",
     "redis",
     "requests",
     "scikit-learn",
@@ -374,6 +377,7 @@
     "sentencepiece",
     "tqdm",
     "transformers",
+    "tree_sitter",
     "uvicorn"
   ],
   "meta-reference-gpu": [

View file

@@ -24,7 +24,6 @@ distribution_spec:
   telemetry:
   - inline::meta-reference
   tool_runtime:
-  - remote::brave-search
   - remote::tavily-search
   - inline::code-interpreter
   - inline::rag-runtime

View file

@@ -75,7 +75,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/lmstudio/trace_store.db}
   tool_runtime:
@@ -125,9 +125,9 @@ models:
   provider_model_id: llama-3.2-3b-instruct
   model_type: llm
 - metadata: {}
-  model_id: llama-3.2-70b-instruct
+  model_id: llama-3.3-70b-instruct
   provider_id: lmstudio
-  provider_model_id: llama-3.2-70b-instruct
+  provider_model_id: llama-3.3-70b-instruct
   model_type: llm
 - metadata:
     embedding_dimension: 768