Refactor the messages marshaling and unmarshaling

pull/25/head
Vasile Popescu 4 years ago committed by Elis Popescu
parent 3c77e059b3
commit 366edcd23e

@ -1,7 +1,6 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -12,7 +11,7 @@ import (
"sync/atomic" "sync/atomic"
"syscall" "syscall"
ttyServer "github.com/elisescu/tty-share/server" "github.com/elisescu/tty-share/server"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/moby/term" "github.com/moby/term"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -20,7 +19,7 @@ import (
type ttyShareClient struct { type ttyShareClient struct {
url string url string
connection *websocket.Conn wsConn *websocket.Conn
detachKeys string detachKeys string
wcChan chan os.Signal wcChan chan os.Signal
writeFlag uint32 // used with atomic writeFlag uint32 // used with atomic
@ -36,7 +35,7 @@ type ttyShareClient struct {
func newTtyShareClient(url string, detachKeys string) *ttyShareClient { func newTtyShareClient(url string, detachKeys string) *ttyShareClient {
return &ttyShareClient{ return &ttyShareClient{
url: url, url: url,
connection: nil, wsConn: nil,
detachKeys: detachKeys, detachKeys: detachKeys,
wcChan: make(chan os.Signal, 1), wcChan: make(chan os.Signal, 1),
writeFlag: 1, writeFlag: 1,
@ -47,15 +46,6 @@ func clearScreen() {
fmt.Fprintf(os.Stdout, "\033[H\033[2J") fmt.Fprintf(os.Stdout, "\033[H\033[2J")
} }
type wsTextWriter struct {
conn *websocket.Conn
}
func (w *wsTextWriter) Write(data []byte) (n int, err error) {
err = w.conn.WriteMessage(websocket.TextMessage, data)
return len(data), err
}
type keyListener struct { type keyListener struct {
wrappedReader io.Reader wrappedReader io.Reader
} }
@ -120,7 +110,7 @@ func (c *ttyShareClient) Run() (err error) {
log.Debugf("Built the WS URL from the headers: %s", wsURL) log.Debugf("Built the WS URL from the headers: %s", wsURL)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) c.wsConn, _, err = websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil { if err != nil {
return return
} }
@ -133,15 +123,12 @@ func (c *ttyShareClient) Run() (err error) {
state, err := term.MakeRaw(os.Stdin.Fd()) state, err := term.MakeRaw(os.Stdin.Fd())
defer term.RestoreTerminal(os.Stdin.Fd(), state) defer term.RestoreTerminal(os.Stdin.Fd(), state)
c.connection = conn
clearScreen() clearScreen()
// start monitoring the size of the terminal
signal.Notify(c.wcChan, syscall.SIGWINCH)
defer signal.Stop(c.wcChan)
monitorWinChanges := func() { monitorWinChanges := func() {
// start monitoring the size of the terminal
signal.Notify(c.wcChan, syscall.SIGWINCH)
for { for {
select { select {
case <-c.wcChan: case <-c.wcChan:
@ -152,58 +139,48 @@ func (c *ttyShareClient) Run() (err error) {
} }
} }
readLoop := func() { protoWS := server.NewTTYProtocolWS(c.wsConn)
for {
var msg ttyServer.MsgAll
_, r, err := conn.NextReader()
if err != nil {
log.Debugf("Connection closed")
return
}
err = json.NewDecoder(r).Decode(&msg)
if err != nil {
log.Errorf("Cannot read JSON: %s", err.Error())
}
switch msg.Type { readLoop := func() {
case ttyServer.MsgIDWrite:
var msgWrite ttyServer.MsgTTYWrite
err := json.Unmarshal(msg.Data, &msgWrite)
if err != nil { var err error
log.Errorf("Cannot read JSON: %s", err.Error()) for {
} err = protoWS.ReadAndHandle(
// onWrite
func(data []byte) {
if atomic.LoadUint32(&c.writeFlag) != 0 {
os.Stdout.Write(data)
}
},
// onWindowSize
func(cols, rows int) {
c.winSizesMutex.Lock()
c.winSizes.remoteW = uint16(cols)
c.winSizes.remoteH = uint16(rows)
c.winSizesMutex.Unlock()
c.updateThisWinSize()
c.updateAndDecideStdoutMuted()
},
)
if atomic.LoadUint32(&c.writeFlag) != 0 { if err != nil {
os.Stdout.Write(msgWrite.Data) log.Errorf("Error parsing remote message: %s", err.Error())
} if err == io.EOF {
case ttyServer.MsgIDWinSize: // Remote WS connection closed
var msgRemoteWinSize ttyServer.MsgTTYWinSize return
err := json.Unmarshal(msg.Data, &msgRemoteWinSize)
if err != nil {
continue
} }
c.winSizesMutex.Lock()
c.winSizes.remoteW = uint16(msgRemoteWinSize.Cols)
c.winSizes.remoteH = uint16(msgRemoteWinSize.Rows)
c.winSizesMutex.Unlock()
c.updateThisWinSize()
c.updateAndDecideStdoutMuted()
} }
} }
} }
writeLoop := func() { writeLoop := func() {
ww := &wsTextWriter{
conn: conn,
}
kl := &keyListener{ kl := &keyListener{
wrappedReader: term.NewEscapeProxy(os.Stdin, detachBytes), wrappedReader: term.NewEscapeProxy(os.Stdin, detachBytes),
} }
_, err := io.Copy(ttyServer.NewTTYProtocolWriter(ww), kl) _, err := io.Copy(protoWS, kl)
if err != nil { if err != nil {
log.Debugf("Connection closed") log.Debugf("Connection closed: %s", err.Error())
c.Stop() c.Stop()
return return
} }
@ -218,5 +195,6 @@ func (c *ttyShareClient) Run() (err error) {
} }
func (c *ttyShareClient) Stop() { func (c *ttyShareClient) Stop() {
c.connection.Close() c.wsConn.Close()
signal.Stop(c.wcChan)
} }

@ -147,7 +147,7 @@ func (server *TTYServer) handleWebsocket(w http.ResponseWriter, r *http.Request)
} }
server.newClientCB(conn.RemoteAddr().String()) server.newClientCB(conn.RemoteAddr().String())
server.session.HandleWSConnection(newWSConnection(conn)) server.session.HandleWSConnection(conn)
} }
func panicIfErr(err error) { func panicIfErr(err error) {

@ -2,30 +2,31 @@ package server
import ( import (
"container/list" "container/list"
"encoding/json"
"io" "io"
"sync" "sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type ttyShareSession struct { type ttyShareSession struct {
mainRWLock sync.RWMutex mainRWLock sync.RWMutex
ttyReceiverConnections *list.List ttyProtoConnections *list.List
isAlive bool isAlive bool
lastWindowSizeMsg MsgTTYWinSize lastWindowSizeMsg MsgTTYWinSize
ttyWriter io.Writer ttyWriter io.Writer
} }
func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession { // quick and dirty locked writer
type lockedWriter struct {
ttyShareSession := &ttyShareSession{ writer io.Writer
ttyReceiverConnections: list.New(), lock sync.Mutex
ttyWriter: ttyWriter, }
}
return ttyShareSession func (wl *lockedWriter) Write(data []byte) (int, error) {
wl.lock.Lock()
defer wl.lock.Unlock()
return wl.writer.Write(data)
} }
func copyList(l *list.List) *list.List { func copyList(l *list.List) *list.List {
@ -36,123 +37,89 @@ func copyList(l *list.List) *list.List {
return newList return newList
} }
func (session *ttyShareSession) WindowSize(cols, rows int) (err error) { func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession {
msg := MsgTTYWinSize{
Cols: cols, ttyShareSession := &ttyShareSession{
Rows: rows, ttyProtoConnections: list.New(),
ttyWriter: ttyWriter,
} }
return ttyShareSession
}
func (session *ttyShareSession) WindowSize(cols, rows int) error {
session.mainRWLock.Lock() session.mainRWLock.Lock()
session.lastWindowSizeMsg = msg session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows}
session.mainRWLock.Unlock() session.mainRWLock.Unlock()
data, _ := MarshalMsg(msg) session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool {
rcvConn.SetWinSize(cols, rows)
session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool {
_, e := rcvConn.WriteRawData(data)
if e != nil {
err = e
}
return true return true
}) })
return return nil
} }
func (session *ttyShareSession) Write(buff []byte) (written int, err error) { func (session *ttyShareSession) Write(data []byte) (int, error) {
msg := MsgTTYWrite{ session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool {
Data: buff, rcvConn.Write(data)
Size: len(buff),
}
data, _ := MarshalMsg(msg)
session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool {
_, e := rcvConn.WriteRawData(data)
if e != nil {
err = e
}
return true return true
}) })
return len(data), nil
// TODO: fix this
written = len(buff)
return
} }
// Runs the callback cb for each of the receivers in the list of the receivers, as it was when // Runs the callback cb for each of the receivers in the list of the receivers, as it was when
// this function was called. Note that there might be receivers which might have lost // this function was called. Note that there might be receivers which might have lost
// the connection since this function was called. // the connection since this function was called.
// Return false in the callback to not continue for the rest of the receivers // Return false in the callback to not continue for the rest of the receivers
func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWriter) bool) { func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWS) bool) {
session.mainRWLock.RLock() session.mainRWLock.RLock()
// TODO: Maybe find a better way? // TODO: Maybe find a better way?
rcvsCopy := copyList(session.ttyReceiverConnections) rcvsCopy := copyList(session.ttyProtoConnections)
session.mainRWLock.RUnlock() session.mainRWLock.RUnlock()
for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() { for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() {
receiver := receiverE.Value.(*TTYProtocolWriter) receiver := receiverE.Value.(*TTYProtocolWS)
if !cb(receiver) { if !cb(receiver) {
break break
} }
} }
} }
// quick and dirty locked writer
type lockedWriter struct {
writer io.Writer
lock sync.Mutex
}
func (wl *lockedWriter) Write(data []byte) (int, error) {
wl.lock.Lock()
defer wl.lock.Unlock()
return wl.writer.Write(data)
}
// Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine) // Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine)
// When HandleWSConnection will exit, the connection to the TTYReceiver will be closed // When HandleWSConnection will exit, the connection to the TTYReceiver will be closed
func (session *ttyShareSession) HandleWSConnection(wsConn *WSConnection) { func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) {
rcvReader := NewTTYProtocolReader(wsConn) protoConn := NewTTYProtocolWS(wsConn)
// Gorilla websockets don't allow for concurent writes. Lazy, and perhaps shorter solution
// is to wrap a lock around a writer. Maybe later replace it with a channel
rcvWriter := NewTTYProtocolWriter(&lockedWriter{
writer: wsConn,
})
// Add the receiver to the list of receivers in the seesion, so we need to write-lock
session.mainRWLock.Lock() session.mainRWLock.Lock()
rcvHandleEl := session.ttyReceiverConnections.PushBack(rcvWriter) rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn)
lastWindowSizeData, _ := MarshalMsg(session.lastWindowSizeMsg) winSize := session.lastWindowSizeMsg
session.mainRWLock.Unlock() session.mainRWLock.Unlock()
log.Debugf("New WS connection (%s). Serving ..", wsConn.Address()) log.Debugf("New WS connection (%s). Serving ..", wsConn.RemoteAddr().String())
// Sending the initial size of the window, if we have one // Sending the initial size of the window, if we have one
rcvWriter.WriteRawData(lastWindowSizeData) protoConn.SetWinSize(winSize.Cols, winSize.Rows)
// Wait until the TTYReceiver will close the connection on its end // Wait until the TTYReceiver will close the connection on its end
for { for {
msg, err := rcvReader.ReadMessage() err := protoConn.ReadAndHandle(
func(data []byte) {
session.ttyWriter.Write(data)
},
func(cols, rows int) {
// Maybe ask the server side to refresh/repaint the tty window?
},
)
if err != nil { if err != nil {
log.Debugf("Finished the WS reading loop: %s", err.Error()) log.Debugf("Finished the WS reading loop: %s", err.Error())
break break
} }
// We only support MsgTTYWrite from the web terminal for now
if msg.Type != MsgIDWrite {
log.Warnf("Unknown message over the WS connection: type %s", msg.Type)
break
}
var msgW MsgTTYWrite
json.Unmarshal(msg.Data, &msgW)
session.ttyWriter.Write(msgW.Data)
} }
// Remove the recevier from the list of the receiver of this session, so we need to write-lock // Remove the recevier from the list of the receiver of this session, so we need to write-lock
session.mainRWLock.Lock() session.mainRWLock.Lock()
session.ttyReceiverConnections.Remove(rcvHandleEl) session.ttyProtoConnections.Remove(rcvHandleEl)
session.mainRWLock.Unlock() session.mainRWLock.Unlock()
wsConn.Close() wsConn.Close()

@ -2,19 +2,10 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"io" "io"
)
type TTYProtocolReader struct {
reader io.Reader
jsonDecoder *json.Decoder
}
type TTYProtocolWriter struct { "github.com/gorilla/websocket"
writer io.Writer )
}
const ( const (
MsgIDWrite = "Write" MsgIDWrite = "Write"
@ -22,7 +13,7 @@ const (
) )
// Message used to encapsulate the rest of the bessages bellow // Message used to encapsulate the rest of the bessages bellow
type MsgAll struct { type MsgWrapper struct {
Type string Type string
Data []byte Data []byte
} }
@ -37,26 +28,21 @@ type MsgTTYWinSize struct {
Rows int Rows int
} }
func ReadAndUnmarshalMsg(reader io.Reader, aMessage interface{}) (err error) { type OnMsgWrite func(data []byte)
var wrapperMsg MsgAll type OnMsgWinSize func(cols, rows int)
// Wait here for the right message to come
dec := json.NewDecoder(reader)
err = dec.Decode(&wrapperMsg)
if err != nil {
return errors.New("Cannot decode top message: " + err.Error())
}
err = json.Unmarshal(wrapperMsg.Data, aMessage) type TTYProtocolWS struct {
ws *websocket.Conn
}
if err != nil { func NewTTYProtocolWS(ws *websocket.Conn) *TTYProtocolWS {
return errors.New("Cannot decode message: " + err.Error() + string(wrapperMsg.Data)) return &TTYProtocolWS{
ws: ws,
} }
return
} }
func MarshalMsg(aMessage interface{}) (_ []byte, err error) { func marshalMsg(aMessage interface{}) (_ []byte, err error) {
var msg MsgAll var msg MsgWrapper
if writeMsg, ok := aMessage.(MsgTTYWrite); ok { if writeMsg, ok := aMessage.(MsgTTYWrite); ok {
msg.Type = MsgIDWrite msg.Type = MsgIDWrite
@ -80,64 +66,62 @@ func MarshalMsg(aMessage interface{}) (_ []byte, err error) {
return nil, nil return nil, nil
} }
func MarshalAndWriteMsg(writer io.Writer, aMessage interface{}) (err error) {
b, err := MarshalMsg(aMessage)
func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) {
var msg MsgWrapper
_, r, err := handler.ws.NextReader()
if err != nil { if err != nil {
return // underlaying conn is closed. signal that through io.EOF
return io.EOF
} }
n, err := writer.Write(b) err = json.NewDecoder(r).Decode(&msg)
if n != len(b) {
err = fmt.Errorf("Unable to write : wrote %d out of %d bytes", n, len(b))
return
}
if err != nil { if err != nil {
return return
} }
return switch msg.Type {
} case MsgIDWrite:
var msgWrite MsgTTYWrite
func NewTTYProtocolWriter(w io.Writer) *TTYProtocolWriter { err = json.Unmarshal(msg.Data, &msgWrite)
return &TTYProtocolWriter{ if err == nil {
writer: w, onWrite(msgWrite.Data)
} }
} case MsgIDWinSize:
var msgRemoteWinSize MsgTTYWinSize
func NewTTYProtocolReader(r io.Reader) *TTYProtocolReader { err = json.Unmarshal(msg.Data, &msgRemoteWinSize)
return &TTYProtocolReader{ if err == nil {
reader: r, onWinSize(msgRemoteWinSize.Cols, msgRemoteWinSize.Rows)
jsonDecoder: json.NewDecoder(r), }
} }
}
func (reader *TTYProtocolReader) ReadMessage() (msg MsgAll, err error) {
// TODO: perhaps read here the error, and transform it to something that's understandable
// from the outside in the context of this object
err = reader.jsonDecoder.Decode(&msg)
return return
} }
func (writer *TTYProtocolWriter) SetWinSize(cols, rows int) error { func (handler *TTYProtocolWS) SetWinSize(cols, rows int) error {
msgWinChanged := MsgTTYWinSize{ msgWinChanged := MsgTTYWinSize{
Cols: cols, Cols: cols,
Rows: rows, Rows: rows,
} }
return MarshalAndWriteMsg(writer.writer, msgWinChanged) data, err := marshalMsg(msgWinChanged)
if err != nil {
return err
}
return handler.ws.WriteMessage(websocket.TextMessage, data)
} }
// Function to send data from one the sender to the server and the other way around. // Function to send data from one the sender to the server and the other way around.
func (writer *TTYProtocolWriter) Write(buff []byte) (int, error) { func (handler *TTYProtocolWS) Write(buff []byte) (int, error) {
msgWrite := MsgTTYWrite{ msgWrite := MsgTTYWrite{
Data: buff, Data: buff,
Size: len(buff), Size: len(buff),
} }
return len(buff), MarshalAndWriteMsg(writer.writer, msgWrite) data, err := marshalMsg(msgWrite)
} if err != nil {
return 0, err
}
func (writer *TTYProtocolWriter) WriteRawData(buff []byte) (int, error) { return len(buff), handler.ws.WriteMessage(websocket.TextMessage, data)
return writer.writer.Write(buff)
} }

@ -1,43 +0,0 @@
package server
import (
"github.com/gorilla/websocket"
)
type WSConnection struct {
connection *websocket.Conn
address string
}
func newWSConnection(conn *websocket.Conn) *WSConnection {
return &WSConnection{
connection: conn,
address: conn.RemoteAddr().String(),
}
}
func (handle *WSConnection) Write(data []byte) (n int, err error) {
w, err := handle.connection.NextWriter(websocket.TextMessage)
if err != nil {
return 0, err
}
n, err = w.Write(data)
w.Close()
return
}
func (handle *WSConnection) Close() (err error) {
return handle.connection.Close()
}
func (handle *WSConnection) Address() string {
return handle.address
}
func (handle *WSConnection) Read(data []byte) (int, error) {
_, r, err := handle.connection.NextReader()
if err != nil {
return 0, err
}
return r.Read(data)
}
Loading…
Cancel
Save