completion() for together (#324)

* completion() for together

* test fixes

* fix client building
Dinesh Yeduguru 2024-10-25 14:21:12 -07:00 committed by GitHub
parent 8a74e400d6
commit 7ec79f3b9d
2 changed files with 86 additions and 34 deletions
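
A minimal sketch of how the new completion() path might be exercised once the adapter is built. The adapter instance, the model alias, and the SamplingParams import path are assumptions for illustration, not part of this commit.

from llama_models.llama3.api.datatypes import SamplingParams  # import path assumed


async def demo(adapter) -> None:
    # `adapter` is assumed to be an initialized TogetherInferenceAdapter whose
    # model alias appears in TOGETHER_SUPPORTED_MODELS (alias assumed below).
    response = await adapter.completion(
        model="Llama3.1-8B-Instruct",
        content="Micheael Jordan is born in ",
        sampling_params=SamplingParams(max_tokens=50),
        stream=False,
    )
    # Non-streaming calls return a CompletionResponse with the text in .content.
    print(response.content)

    # With stream=True, completion() returns an async generator of chunks instead.
    stream = await adapter.completion(
        model="Llama3.1-8B-Instruct",
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=True,
    )
    async for chunk in stream:
        print(chunk)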

View file

@@ -20,9 +20,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )

 from .config import TogetherImplConfig
@@ -41,6 +44,7 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
@@ -49,7 +53,7 @@ class TogetherInferenceAdapter(
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
-        return
+        pass

     async def shutdown(self) -> None:
         pass
@@ -63,7 +67,76 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_completion(request)
+        r = self._get_client().completions.create(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            s = self._get_client().completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.response_format),
+        }
+
     async def chat_completion(
         self,
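
When TogetherImplConfig.api_key is not set, the new _get_client() falls back to per-request provider data, so a caller of a running llama-stack server supplies the key in the header named by the error message above. A hedged sketch of such a request follows; the server URL, route, payload shape, and model alias are assumptions, not part of this commit.

import json

import requests

headers = {
    "Content-Type": "application/json",
    # The exact header and JSON shape _get_client() asks for:
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}
payload = {
    "model": "Llama3.1-8B-Instruct",  # model alias assumed
    "content": "Micheael Jordan is born in ",
    "stream": False,
}
resp = requests.post(
    "http://localhost:5000/inference/completion",  # host and route assumed
    headers=headers,
    json=payload,
)
print(resp.json())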
@@ -77,18 +150,7 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-        client = Together(api_key=together_api_key)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
@@ -102,25 +164,25 @@ class TogetherInferenceAdapter(
         )

         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
@@ -131,23 +193,11 @@ class TogetherInferenceAdapter(
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        options = get_sampling_options(request.sampling_params)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
        }

     async def embeddings(
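
The response-format handling previously inlined in _get_params() is now shared with the completion path via _build_options(). As a rough illustration (not code from this commit) of the mapping it produces for a json_schema format, the stand-in object below only mimics the .type and .schema attributes the helper reads; it is not a llama-stack class.

from types import SimpleNamespace

# Hypothetical stand-in for a ResponseFormat of type "json_schema".
fmt = SimpleNamespace(
    type="json_schema",
    schema={"type": "object", "properties": {"year": {"type": "integer"}}},
)

# _build_options() merges the sampling options with roughly this mapping,
# which is what Together's completions.create() receives as response_format:
options = {"response_format": {"type": "json_object", "schema": fmt.schema}}
print(options)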

View file

@@ -138,11 +138,12 @@ async def test_completion(inference_settings):
         "meta-reference",
         "remote::ollama",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip("Other inference providers don't support completion() yet")

     response = await inference_impl.completion(
-        content="Roses are red,",
+        content="Micheael Jordan is born in ",
         stream=False,
         model=params["model"],
         sampling_params=SamplingParams(
@@ -151,7 +152,7 @@ async def test_completion(inference_settings):
     )

     assert isinstance(response, CompletionResponse)
-    assert "violets are blue" in response.content
+    assert "1963" in response.content

     chunks = [
         r
@@ -180,6 +181,7 @@ async def test_completions_structured_output(inference_settings):
     if provider.__provider_spec__.provider_type not in (
         "meta-reference",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip(
             "Other inference providers don't support structured output in completions yet"