llama-stack-mirror/llama_stack/providers/remote/inference/lmstudio/lmstudio.py
2025-04-25 14:47:23 -04:00

127 lines
4.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncIterator, List, Optional, Union
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
ResponseFormatType,
)
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media,
)
from .models import MODEL_ENTRIES
class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Inference provider adapter that forwards requests to an LM Studio server.

    Every access to :attr:`client` constructs a new ``LMStudioClient`` for
    ``self.url``. The request methods therefore bind one client per request, so
    the model handle they resolve is used with the same client that produced it.

    ``self.model_store`` is read by the request methods but is never assigned in
    this class. Presumably it is injected by the stack after construction
    (TODO confirm).
    """

    def __init__(self, url: str) -> None:
        """Store the LM Studio server URL and build the model registry helper.

        :param url: Address of the LM Studio server, passed as-is to ``LMStudioClient``.
        """
        self.url = url
        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)

    @property
    def client(self) -> LMStudioClient:
        """Return a newly constructed ``LMStudioClient`` for ``self.url`` (not cached)."""
        return LMStudioClient(url=self.url)

    async def initialize(self) -> None:
        """No-op: this adapter performs no asynchronous setup."""
        pass

    async def register_model(self, model):
        """Register ``model`` with the registry helper and return it unchanged."""
        await self.register_helper.register_model(model)
        return model

    async def unregister_model(self, model_id):
        """Currently a no-op; ``model_id`` is ignored."""
        pass

    @staticmethod
    def _json_schema_from_response_format(response_format: Optional[ResponseFormat]):
        """Validate ``response_format`` and return its JSON schema, or ``None`` if unset.

        Only JSON-schema response formats are forwarded by this adapter; any other
        type (e.g. grammar) is rejected.

        :raises ValueError: if ``response_format`` is set but is not a JSON-schema format.
        """
        if response_format is None:
            return None
        if response_format.type != ResponseFormatType.json_schema.value:
            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
        return response_format.json_schema

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embed ``contents`` with the LM Studio model registered under ``model_id``.

        ``text_truncation``, ``output_dimension`` and ``task_type`` are accepted for
        interface compatibility but are not forwarded to LM Studio.

        :raises NotImplementedError: if any item in ``contents`` contains media.
        """
        # Explicit raise rather than ``assert`` so the check is not stripped under ``python -O``.
        if any(content_has_media(content) for content in contents):
            raise NotImplementedError("Media content not supported in embedding model")
        model = await self.model_store.get_model(model_id)
        # One client for the whole request: the model handle is used with the client that resolved it.
        client = self.client
        embedding_model = await client.get_embedding_model(model.provider_model_id)
        embeddings = await client.embed(embedding_model, contents)
        return EmbeddingsResponse(embeddings=embeddings)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        """Run a chat completion against the LM Studio model registered under ``model_id``.

        Forwarded to LM Studio: ``messages``, ``sampling_params``, the JSON schema of
        ``response_format``, ``stream``, ``tool_config`` and ``tools``.
        Not forwarded: ``tool_choice``, ``tool_prompt_format`` and ``logprobs``.

        :returns: the result of ``LMStudioClient.llm_respond``. Per the annotation, this
            is a ``ChatCompletionResponse``, or an async iterator of chunks when streaming.
        :raises ValueError: if ``response_format`` is not a JSON-schema format.
        """
        # Validate arguments before contacting LM Studio, so a rejected request
        # does not first resolve (and possibly load) the model.
        json_schema = self._json_schema_from_response_format(response_format)
        model = await self.model_store.get_model(model_id)
        client = self.client
        llm = await client.get_llm(model.provider_model_id)
        return await client.llm_respond(
            llm=llm,
            messages=messages,
            sampling_params=sampling_params,
            json_schema=json_schema,
            stream=stream,
            tool_config=tool_config,
            tools=tools,
        )

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,  # Skip this for now
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        """Run a text completion of ``content`` with the LM Studio model ``model_id``.

        ``logprobs`` is accepted but not forwarded.

        :returns: the result of ``LMStudioClient.llm_completion``. Per the annotation, this
            is a ``CompletionResponse``, or an async iterator of chunks when streaming.
        :raises NotImplementedError: if ``content`` contains media.
        :raises ValueError: if ``response_format`` is not a JSON-schema format.
        """
        # Validate arguments before contacting LM Studio.
        if content_has_media(content):
            raise NotImplementedError("Media content not supported in LM Studio Provider")
        json_schema = self._json_schema_from_response_format(response_format)
        model = await self.model_store.get_model(model_id)
        client = self.client
        llm = await client.get_llm(model.provider_model_id)
        return await client.llm_completion(llm, content, sampling_params, json_schema, stream)