community: sambaverse api update (#21816)

- **Description:** fix sambaverse integration to make it compatible with
sambaverse API update / minor changes in docs
pull/21450/head^2
Jorge Piedrahita Ortiz 2 weeks ago committed by GitHub
parent 7976fb1663
commit 700b1c7212
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -22,7 +22,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
"**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
" **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
]
},
{
@ -88,9 +89,10 @@
" \"temperature\": 0.01,\n",
" \"process_prompt\": True,\n",
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
" # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
" # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
" # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
@ -177,10 +179,10 @@
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
" # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
" # \"top_logprobs\": {\"type\": \"int\", \"value\": \"0\"},\n",
" # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",

@ -47,8 +47,19 @@ class SVEndpointHandler:
"""
result: Dict[str, Any] = {}
try:
text_result = response.text.strip().split("\n")[-1]
result = {"data": json.loads("".join(text_result.split("data: ")[1:]))}
lines_result = response.text.strip().split("\n")
text_result = lines_result[-1]
if response.status_code == 200 and json.loads(text_result).get("error"):
completion = ""
for line in lines_result[:-1]:
completion += json.loads(line)["result"]["responses"][0][
"stream_token"
]
text_result = lines_result[-2]
result = json.loads(text_result)
result["result"]["responses"][0]["completion"] = completion
else:
result = json.loads(text_result)
except Exception as e:
result["detail"] = str(e)
if "status_code" not in result:
@ -58,25 +69,19 @@ class SVEndpointHandler:
@staticmethod
def _process_streaming_response(
response: requests.Response,
) -> Generator[GenerationChunk, None, None]:
) -> Generator[Dict, None, None]:
"""Process the streaming response"""
try:
import sseclient
except ImportError:
raise ImportError(
"could not import sseclient library"
"Please install it with `pip install sseclient-py`."
)
client = sseclient.SSEClient(response)
close_conn = False
for event in client.events():
if event.event == "error_event":
close_conn = True
text = json.dumps({"event": event.event, "data": event.data})
chunk = GenerationChunk(text=text)
yield chunk
if close_conn:
client.close()
for line in response.iter_lines():
chunk = json.loads(line)
if "status_code" not in chunk:
chunk["status_code"] = response.status_code
if chunk["status_code"] == 200 and chunk.get("error"):
chunk["result"] = {"responses": [{"stream_token": ""}]}
return chunk
yield chunk
except Exception as e:
raise RuntimeError(f"Error processing streaming response: {e}")
def _get_full_url(self) -> str:
"""
@ -105,25 +110,21 @@ class SVEndpointHandler:
:returns: Prediction results
:rtype: dict
"""
if isinstance(input, str):
input = [input]
parsed_input = []
for element in input:
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": element,
}
],
}
parsed_input.append(json.dumps(parsed_element))
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
if params:
data = {"inputs": parsed_input, "params": json.loads(params)}
data = {"instance": parsed_input, "params": json.loads(params)}
else:
data = {"inputs": parsed_input}
data = {"instance": parsed_input}
response = self.http_session.post(
self._get_full_url(),
headers={
@ -141,7 +142,7 @@ class SVEndpointHandler:
sambaverse_model_name: Optional[str],
input: Union[List[str], str],
params: Optional[str] = "",
) -> Iterator[GenerationChunk]:
) -> Iterator[Dict]:
"""
NLP predict using inline input string.
@ -153,25 +154,21 @@ class SVEndpointHandler:
:returns: Prediction results
:rtype: dict
"""
if isinstance(input, str):
input = [input]
parsed_input = []
for element in input:
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": element,
}
],
}
parsed_input.append(json.dumps(parsed_element))
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
if params:
data = {"inputs": parsed_input, "params": json.loads(params)}
data = {"instance": parsed_input, "params": json.loads(params)}
else:
data = {"inputs": parsed_input}
data = {"instance": parsed_input}
# Streaming output
response = self.http_session.post(
self._get_full_url(),
@ -213,7 +210,7 @@ class Sambaverse(LLM):
"max_tokens_to_generate": 100,
"temperature": 0.7,
"top_p": 1.0,
"repetition_penalty": 1,
"repetition_penalty": 1.0,
"top_k": 50,
},
)
@ -279,13 +276,17 @@ class Sambaverse(LLM):
The tuning parameters as a JSON string.
"""
_model_kwargs = self.model_kwargs or {}
_stop_sequences = _model_kwargs.get("stop_sequences", [])
_stop_sequences = stop or _stop_sequences
_model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences)
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
_stop_sequences = stop or _kwarg_stop_sequences
if not _kwarg_stop_sequences:
_model_kwargs["stop_sequences"] = ",".join(
f'"{x}"' for x in _stop_sequences
)
tuning_params_dict = {
k: {"type": type(v).__name__, "value": str(v)}
for k, v in (_model_kwargs.items())
}
_model_kwargs["stop_sequences"] = _kwarg_stop_sequences
tuning_params = json.dumps(tuning_params_dict)
return tuning_params
@ -313,14 +314,17 @@ class Sambaverse(LLM):
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
)
if response["status_code"] != 200:
optional_details = response["details"]
optional_message = response["message"]
optional_code = response["error"].get("code")
optional_details = response["error"].get("details")
optional_message = response["error"].get("message")
raise ValueError(
f"Sambanova /complete call failed with status code "
f"{response['status_code']}. Details: {optional_details}"
f"{response['status_code']}. Message: {optional_message}"
f"{response['status_code']}."
f"Message: {optional_message}"
f"Details: {optional_details}"
f"Code: {optional_code}"
)
return response["data"]["completion"]
return response["result"]["responses"][0]["completion"]
def _handle_completion_requests(
self, prompt: Union[List[str], str], stop: Optional[List[str]]
@ -359,7 +363,20 @@ class Sambaverse(LLM):
for chunk in sdk.nlp_predict_stream(
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
):
yield chunk
if chunk["status_code"] != 200:
optional_code = chunk["error"].get("code")
optional_details = chunk["error"].get("details")
optional_message = chunk["error"].get("message")
raise ValueError(
f"Sambanova /complete call failed with status code "
f"{chunk['status_code']}."
f"Message: {optional_message}"
f"Details: {optional_details}"
f"Code: {optional_code}"
)
text = chunk["result"]["responses"][0]["stream_token"]
generated_chunk = GenerationChunk(text=text)
yield generated_chunk
def _stream(
self,

Loading…
Cancel
Save