@ -47,8 +47,19 @@ class SVEndpointHandler:
"""
result : Dict [ str , Any ] = { }
try :
text_result = response . text . strip ( ) . split ( " \n " ) [ - 1 ]
result = { " data " : json . loads ( " " . join ( text_result . split ( " data: " ) [ 1 : ] ) ) }
lines_result = response . text . strip ( ) . split ( " \n " )
text_result = lines_result [ - 1 ]
if response . status_code == 200 and json . loads ( text_result ) . get ( " error " ) :
completion = " "
for line in lines_result [ : - 1 ] :
completion + = json . loads ( line ) [ " result " ] [ " responses " ] [ 0 ] [
" stream_token "
]
text_result = lines_result [ - 2 ]
result = json . loads ( text_result )
result [ " result " ] [ " responses " ] [ 0 ] [ " completion " ] = completion
else :
result = json . loads ( text_result )
except Exception as e :
result [ " detail " ] = str ( e )
if " status_code " not in result :
@ -58,25 +69,19 @@ class SVEndpointHandler:
@staticmethod
def _process_streaming_response (
response : requests . Response ,
) - > Generator [ GenerationChunk , None , None ] :
) - > Generator [ Dict , None , None ] :
""" Process the streaming response """
try :
import sseclient
except ImportError :
raise ImportError (
" could not import sseclient library "
" Please install it with `pip install sseclient-py`. "
)
client = sseclient . SSEClient ( response )
close_conn = False
for event in client . events ( ) :
if event . event == " error_event " :
close_conn = True
text = json . dumps ( { " event " : event . event , " data " : event . data } )
chunk = GenerationChunk ( text = text )
yield chunk
if close_conn :
client . close ( )
for line in response . iter_lines ( ) :
chunk = json . loads ( line )
if " status_code " not in chunk :
chunk [ " status_code " ] = response . status_code
if chunk [ " status_code " ] == 200 and chunk . get ( " error " ) :
chunk [ " result " ] = { " responses " : [ { " stream_token " : " " } ] }
return chunk
yield chunk
except Exception as e :
raise RuntimeError ( f " Error processing streaming response: { e } " )
def _get_full_url ( self ) - > str :
"""
@ -105,25 +110,21 @@ class SVEndpointHandler:
: returns : Prediction results
: rtype : dict
"""
if isinstance ( input , str ) :
input = [ input ]
parsed_input = [ ]
for element in input :
parsed_element = {
" conversation_id " : " sambaverse-conversation-id " ,
" messages " : [
{
" message_id " : 0 ,
" role " : " user " ,
" content " : element ,
}
] ,
}
parsed_input . append ( json . dumps ( parsed_element ) )
parsed_element = {
" conversation_id " : " sambaverse-conversation-id " ,
" messages " : [
{
" message_id " : 0 ,
" role " : " user " ,
" content " : input ,
}
] ,
}
parsed_input = json . dumps ( parsed_element )
if params :
data = { " in put s" : parsed_input , " params " : json . loads ( params ) }
data = { " instance " : parsed_input , " params " : json . loads ( params ) }
else :
data = { " in put s" : parsed_input }
data = { " instance " : parsed_input }
response = self . http_session . post (
self . _get_full_url ( ) ,
headers = {
@ -141,7 +142,7 @@ class SVEndpointHandler:
sambaverse_model_name : Optional [ str ] ,
input : Union [ List [ str ] , str ] ,
params : Optional [ str ] = " " ,
) - > Iterator [ GenerationChunk ] :
) - > Iterator [ Dict ] :
"""
NLP predict using inline input string .
@ -153,25 +154,21 @@ class SVEndpointHandler:
: returns : Prediction results
: rtype : dict
"""
if isinstance ( input , str ) :
input = [ input ]
parsed_input = [ ]
for element in input :
parsed_element = {
" conversation_id " : " sambaverse-conversation-id " ,
" messages " : [
{
" message_id " : 0 ,
" role " : " user " ,
" content " : element ,
}
] ,
}
parsed_input . append ( json . dumps ( parsed_element ) )
parsed_element = {
" conversation_id " : " sambaverse-conversation-id " ,
" messages " : [
{
" message_id " : 0 ,
" role " : " user " ,
" content " : input ,
}
] ,
}
parsed_input = json . dumps ( parsed_element )
if params :
data = { " in put s" : parsed_input , " params " : json . loads ( params ) }
data = { " instance " : parsed_input , " params " : json . loads ( params ) }
else :
data = { " in put s" : parsed_input }
data = { " instance " : parsed_input }
# Streaming output
response = self . http_session . post (
self . _get_full_url ( ) ,
@ -213,7 +210,7 @@ class Sambaverse(LLM):
" max_tokens_to_generate " : 100 ,
" temperature " : 0.7 ,
" top_p " : 1.0 ,
" repetition_penalty " : 1 ,
" repetition_penalty " : 1.0 ,
" top_k " : 50 ,
} ,
)
@ -279,13 +276,17 @@ class Sambaverse(LLM):
The tuning parameters as a JSON string .
"""
_model_kwargs = self . model_kwargs or { }
_stop_sequences = _model_kwargs . get ( " stop_sequences " , [ ] )
_stop_sequences = stop or _stop_sequences
_model_kwargs [ " stop_sequences " ] = " , " . join ( f ' " { x } " ' for x in _stop_sequences )
_kwarg_stop_sequences = _model_kwargs . get ( " stop_sequences " , [ ] )
_stop_sequences = stop or _kwarg_stop_sequences
if not _kwarg_stop_sequences :
_model_kwargs [ " stop_sequences " ] = " , " . join (
f ' " { x } " ' for x in _stop_sequences
)
tuning_params_dict = {
k : { " type " : type ( v ) . __name__ , " value " : str ( v ) }
for k , v in ( _model_kwargs . items ( ) )
}
_model_kwargs [ " stop_sequences " ] = _kwarg_stop_sequences
tuning_params = json . dumps ( tuning_params_dict )
return tuning_params
@ -313,14 +314,17 @@ class Sambaverse(LLM):
self . sambaverse_api_key , self . sambaverse_model_name , prompt , tuning_params
)
if response [ " status_code " ] != 200 :
optional_details = response [ " details " ]
optional_message = response [ " message " ]
optional_code = response [ " error " ] . get ( " code " )
optional_details = response [ " error " ] . get ( " details " )
optional_message = response [ " error " ] . get ( " message " )
raise ValueError (
f " Sambanova /complete call failed with status code "
f " { response [ ' status_code ' ] } . Details: { optional_details } "
f " { response [ ' status_code ' ] } . Message: { optional_message } "
f " { response [ ' status_code ' ] } . "
f " Message: { optional_message } "
f " Details: { optional_details } "
f " Code: { optional_code } "
)
return response [ " data " ] [ " completion " ]
return response [ " result" ] [ " responses " ] [ 0 ] [ " completion " ]
def _handle_completion_requests (
self , prompt : Union [ List [ str ] , str ] , stop : Optional [ List [ str ] ]
@ -359,7 +363,20 @@ class Sambaverse(LLM):
for chunk in sdk . nlp_predict_stream (
self . sambaverse_api_key , self . sambaverse_model_name , prompt , tuning_params
) :
yield chunk
if chunk [ " status_code " ] != 200 :
optional_code = chunk [ " error " ] . get ( " code " )
optional_details = chunk [ " error " ] . get ( " details " )
optional_message = chunk [ " error " ] . get ( " message " )
raise ValueError (
f " Sambanova /complete call failed with status code "
f " { chunk [ ' status_code ' ] } . "
f " Message: { optional_message } "
f " Details: { optional_details } "
f " Code: { optional_code } "
)
text = chunk [ " result " ] [ " responses " ] [ 0 ] [ " stream_token " ]
generated_chunk = GenerationChunk ( text = text )
yield generated_chunk
def _stream (
self ,