completion() for together (#324)

* completion() for together

* test fixes

* fix client building
Dinesh Yeduguru 2024-10-25 14:21:12 -07:00 committed by GitHub
parent 8a74e400d6
commit 7ec79f3b9d
2 changed files with 86 additions and 34 deletions
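
The main change wires a plain-text `completion()` path into the Together adapter alongside the existing `chat_completion()`, and consolidates API-key handling into a `_get_client()` helper. Below is a minimal usage sketch; the adapter's import path, the config construction, and the model id are illustrative assumptions not shown in this diff, and only the `completion()` call shape comes from the change:

```python
import asyncio

# Module path is assumed for illustration; TogetherInferenceAdapter and
# TogetherImplConfig are the names that appear in the diff below.
from llama_stack.providers.adapters.inference.together.together import (
    TogetherInferenceAdapter,
)
from llama_stack.providers.adapters.inference.together.config import (
    TogetherImplConfig,
)


async def main() -> None:
    # config.api_key is read by _get_client(); without it, the adapter falls
    # back to per-request provider data (see the header example after the diff).
    adapter = TogetherInferenceAdapter(TogetherImplConfig(api_key="..."))
    await adapter.initialize()

    # Non-streaming: completion() builds a CompletionRequest and returns
    # a single processed response.
    response = await adapter.completion(
        model="Llama3.1-8B-Instruct",  # assumed entry in TOGETHER_SUPPORTED_MODELS
        content="Write a haiku about code review.",
        stream=False,
    )
    print(response)

    # Streaming: the awaited call returns an async generator of chunks.
    async for chunk in await adapter.completion(
        model="Llama3.1-8B-Instruct",
        content="Write a haiku about code review.",
        stream=True,
    ):
        print(chunk)


asyncio.run(main())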


@@ -20,9 +20,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )
 
 from .config import TogetherImplConfig
@@ -41,6 +44,7 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
@@ -49,7 +53,7 @@ class TogetherInferenceAdapter(
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
-        return
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -63,7 +67,76 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_completion(request)
+        r = self._get_client().completions.create(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            s = self._get_client().completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.response_format),
+        }
 
     async def chat_completion(
         self,
@@ -77,18 +150,7 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-        client = Together(api_key=together_api_key)
 
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
@@ -102,25 +164,25 @@ class TogetherInferenceAdapter(
         )
 
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
@@ -131,23 +193,11 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        options = get_sampling_options(request.sampling_params)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }
 
     async def embeddings(
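
With client construction centralized in `_get_client()`, a deployment that leaves `config.api_key` unset must supply the key per request. Here is a sketch of such a call; the server URL and route are assumptions, while the header name and JSON shape are taken from the `ValueError` message in `_get_client()`:

```python
import json

import requests  # client-side dependency assumed for this sketch

resp = requests.post(
    "http://localhost:5000/inference/completion",  # assumed server URL and route
    headers={
        # Header name and payload shape come from the error message in _get_client().
        "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
    },
    json={
        "model": "Llama3.1-8B-Instruct",  # assumed model id
        "content": "Write a haiku about code review.",
    },
)
print(resp.json())
```

Config-provided keys take precedence: the header is only consulted when `self.config.api_key` is None.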