Pass 1 for pre-commit fixes

This commit is contained in:
Matt Clayton 2025-04-27 15:24:37 -04:00
parent cfc6bdae68
commit 59e1c5f4a0
7 changed files with 119 additions and 109 deletions

View file

@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
@ -29,7 +31,6 @@ from llama_stack.apis.inference.inference import (
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
ResponseFormatType,
)
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
@ -50,6 +51,18 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
def client(self) -> LMStudioClient:
return LMStudioClient(url=self.url)
async def batch_chat_completion(self, *args, **kwargs):
raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
async def batch_completion(self, *args, **kwargs):
raise NotImplementedError("Batch completion not supported by LM Studio Provider")
async def openai_chat_completion(self, *args, **kwargs):
raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
async def openai_completion(self, *args, **kwargs):
raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
async def initialize(self) -> None:
pass
@ -71,9 +84,12 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
assert all(not content_has_media(content) for content in contents), (
"Media content not supported in embedding model"
)
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
embedding_model = await self.client.get_embedding_model(model.provider_model_id)
embeddings = await self.client.embed(embedding_model, contents)
string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
embeddings = await self.client.embed(embedding_model, string_contents)
return EmbeddingsResponse(embeddings=embeddings)
async def chat_completion(
@ -81,26 +97,31 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[
Union[JsonSchemaResponseFormat, GrammarResponseFormat]
] = None, # Moved and type changed
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
json_schema = response_format.json_schema if response_format else None
json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_respond(
llm=llm,
messages=messages,
sampling_params=sampling_params,
json_schema=json_schema,
json_schema=json_schema_format,
stream=stream,
tool_config=tool_config,
tools=tools,
@ -115,13 +136,16 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, # Skip this for now
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if self.model_store is None:
raise ValueError("ModelStore is not initialized")
model = await self.model_store.get_model(model_id)
llm = await self.client.get_llm(model.provider_model_id)
if content_has_media(content):
raise NotImplementedError("Media content not supported in LM Studio Provider")
if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
json_schema = response_format.json_schema if response_format else None
if not isinstance(response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
)
return await self.client.llm_completion(llm, content, sampling_params, json_schema, stream)
return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)