From 366edcd23e1ad0e6ad3c1ea97133e0988b3b9c3a Mon Sep 17 00:00:00 2001 From: Vasile Popescu Date: Sat, 10 Oct 2020 11:10:17 +0200 Subject: [PATCH] Refactor the messages marshaling and unmarshaling --- client.go | 96 +++++++++-------------- server/server.go | 2 +- server/session.go | 135 ++++++++++++-------------------- server/tty_protocol_rw.go | 108 +++++++++++-------------- server/websockets_connection.go | 43 ---------- 5 files changed, 135 insertions(+), 249 deletions(-) delete mode 100644 server/websockets_connection.go diff --git a/client.go b/client.go index 8e562a0..c2fce48 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,6 @@ package main import ( - "encoding/json" "fmt" "io" "net/http" @@ -12,7 +11,7 @@ import ( "sync/atomic" "syscall" - ttyServer "github.com/elisescu/tty-share/server" + "github.com/elisescu/tty-share/server" "github.com/gorilla/websocket" "github.com/moby/term" log "github.com/sirupsen/logrus" @@ -20,7 +19,7 @@ import ( type ttyShareClient struct { url string - connection *websocket.Conn + wsConn *websocket.Conn detachKeys string wcChan chan os.Signal writeFlag uint32 // used with atomic @@ -36,7 +35,7 @@ type ttyShareClient struct { func newTtyShareClient(url string, detachKeys string) *ttyShareClient { return &ttyShareClient{ url: url, - connection: nil, + wsConn: nil, detachKeys: detachKeys, wcChan: make(chan os.Signal, 1), writeFlag: 1, @@ -47,15 +46,6 @@ func clearScreen() { fmt.Fprintf(os.Stdout, "\033[H\033[2J") } -type wsTextWriter struct { - conn *websocket.Conn -} - -func (w *wsTextWriter) Write(data []byte) (n int, err error) { - err = w.conn.WriteMessage(websocket.TextMessage, data) - return len(data), err -} - type keyListener struct { wrappedReader io.Reader } @@ -120,7 +110,7 @@ func (c *ttyShareClient) Run() (err error) { log.Debugf("Built the WS URL from the headers: %s", wsURL) - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + c.wsConn, _, err = websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { return } @@ -133,15 +123,12 @@ func (c *ttyShareClient) Run() (err error) { state, err := term.MakeRaw(os.Stdin.Fd()) defer term.RestoreTerminal(os.Stdin.Fd(), state) - - c.connection = conn - clearScreen() - // start monitoring the size of the terminal - signal.Notify(c.wcChan, syscall.SIGWINCH) - defer signal.Stop(c.wcChan) monitorWinChanges := func() { + // start monitoring the size of the terminal + signal.Notify(c.wcChan, syscall.SIGWINCH) + for { select { case <-c.wcChan: @@ -152,58 +139,48 @@ func (c *ttyShareClient) Run() (err error) { } } - readLoop := func() { - for { - var msg ttyServer.MsgAll - _, r, err := conn.NextReader() - if err != nil { - log.Debugf("Connection closed") - return - } - err = json.NewDecoder(r).Decode(&msg) - if err != nil { - log.Errorf("Cannot read JSON: %s", err.Error()) - } + protoWS := server.NewTTYProtocolWS(c.wsConn) - switch msg.Type { - case ttyServer.MsgIDWrite: - var msgWrite ttyServer.MsgTTYWrite - err := json.Unmarshal(msg.Data, &msgWrite) + readLoop := func() { - if err != nil { - log.Errorf("Cannot read JSON: %s", err.Error()) - } + var err error + for { + err = protoWS.ReadAndHandle( + // onWrite + func(data []byte) { + if atomic.LoadUint32(&c.writeFlag) != 0 { + os.Stdout.Write(data) + } + }, + // onWindowSize + func(cols, rows int) { + c.winSizesMutex.Lock() + c.winSizes.remoteW = uint16(cols) + c.winSizes.remoteH = uint16(rows) + c.winSizesMutex.Unlock() + c.updateThisWinSize() + c.updateAndDecideStdoutMuted() + }, + ) - if atomic.LoadUint32(&c.writeFlag) != 0 { - os.Stdout.Write(msgWrite.Data) - } - case ttyServer.MsgIDWinSize: - var msgRemoteWinSize ttyServer.MsgTTYWinSize - err := json.Unmarshal(msg.Data, &msgRemoteWinSize) - if err != nil { - continue + if err != nil { + log.Errorf("Error parsing remote message: %s", err.Error()) + if err == io.EOF { + // Remote WS connection closed + return } - c.winSizesMutex.Lock() - c.winSizes.remoteW = uint16(msgRemoteWinSize.Cols) - c.winSizes.remoteH = uint16(msgRemoteWinSize.Rows) - c.winSizesMutex.Unlock() - c.updateThisWinSize() - c.updateAndDecideStdoutMuted() } } } writeLoop := func() { - ww := &wsTextWriter{ - conn: conn, - } kl := &keyListener{ wrappedReader: term.NewEscapeProxy(os.Stdin, detachBytes), } - _, err := io.Copy(ttyServer.NewTTYProtocolWriter(ww), kl) + _, err := io.Copy(protoWS, kl) if err != nil { - log.Debugf("Connection closed") + log.Debugf("Connection closed: %s", err.Error()) c.Stop() return } @@ -218,5 +195,6 @@ func (c *ttyShareClient) Run() (err error) { } func (c *ttyShareClient) Stop() { - c.connection.Close() + c.wsConn.Close() + signal.Stop(c.wcChan) } diff --git a/server/server.go b/server/server.go index 2830bc8..51d87bf 100644 --- a/server/server.go +++ b/server/server.go @@ -147,7 +147,7 @@ func (server *TTYServer) handleWebsocket(w http.ResponseWriter, r *http.Request) } server.newClientCB(conn.RemoteAddr().String()) - server.session.HandleWSConnection(newWSConnection(conn)) + server.session.HandleWSConnection(conn) } func panicIfErr(err error) { diff --git a/server/session.go b/server/session.go index 340f33b..d66282d 100644 --- a/server/session.go +++ b/server/session.go @@ -2,30 +2,31 @@ package server import ( "container/list" - "encoding/json" - "io" "sync" + "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" ) type ttyShareSession struct { - mainRWLock sync.RWMutex - ttyReceiverConnections *list.List - isAlive bool - lastWindowSizeMsg MsgTTYWinSize - ttyWriter io.Writer + mainRWLock sync.RWMutex + ttyProtoConnections *list.List + isAlive bool + lastWindowSizeMsg MsgTTYWinSize + ttyWriter io.Writer } -func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession { - - ttyShareSession := &ttyShareSession{ - ttyReceiverConnections: list.New(), - ttyWriter: ttyWriter, - } +// quick and dirty locked writer +type lockedWriter struct { + writer io.Writer + lock sync.Mutex +} - return ttyShareSession +func (wl *lockedWriter) Write(data []byte) (int, error) { + wl.lock.Lock() + defer wl.lock.Unlock() + return wl.writer.Write(data) } func copyList(l *list.List) *list.List { @@ -36,123 +37,89 @@ func copyList(l *list.List) *list.List { return newList } -func (session *ttyShareSession) WindowSize(cols, rows int) (err error) { - msg := MsgTTYWinSize{ - Cols: cols, - Rows: rows, +func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession { + + ttyShareSession := &ttyShareSession{ + ttyProtoConnections: list.New(), + ttyWriter: ttyWriter, } + return ttyShareSession +} + +func (session *ttyShareSession) WindowSize(cols, rows int) error { session.mainRWLock.Lock() - session.lastWindowSizeMsg = msg + session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows} session.mainRWLock.Unlock() - data, _ := MarshalMsg(msg) - - session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool { - _, e := rcvConn.WriteRawData(data) - if e != nil { - err = e - } + session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { + rcvConn.SetWinSize(cols, rows) return true }) - return + return nil } -func (session *ttyShareSession) Write(buff []byte) (written int, err error) { - msg := MsgTTYWrite{ - Data: buff, - Size: len(buff), - } - - data, _ := MarshalMsg(msg) - - session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool { - _, e := rcvConn.WriteRawData(data) - if e != nil { - err = e - } +func (session *ttyShareSession) Write(data []byte) (int, error) { + session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { + rcvConn.Write(data) return true }) - - // TODO: fix this - written = len(buff) - return + return len(data), nil } // Runs the callback cb for each of the receivers in the list of the receivers, as it was when // this function was called. Note that there might be receivers which might have lost // the connection since this function was called. // Return false in the callback to not continue for the rest of the receivers -func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWriter) bool) { +func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWS) bool) { session.mainRWLock.RLock() // TODO: Maybe find a better way? - rcvsCopy := copyList(session.ttyReceiverConnections) + rcvsCopy := copyList(session.ttyProtoConnections) session.mainRWLock.RUnlock() for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() { - receiver := receiverE.Value.(*TTYProtocolWriter) + receiver := receiverE.Value.(*TTYProtocolWS) if !cb(receiver) { break } } } -// quick and dirty locked writer -type lockedWriter struct { - writer io.Writer - lock sync.Mutex -} - -func (wl *lockedWriter) Write(data []byte) (int, error) { - wl.lock.Lock() - defer wl.lock.Unlock() - return wl.writer.Write(data) -} - // Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine) // When HandleWSConnection will exit, the connection to the TTYReceiver will be closed -func (session *ttyShareSession) HandleWSConnection(wsConn *WSConnection) { - rcvReader := NewTTYProtocolReader(wsConn) - - // Gorilla websockets don't allow for concurent writes. Lazy, and perhaps shorter solution - // is to wrap a lock around a writer. Maybe later replace it with a channel - rcvWriter := NewTTYProtocolWriter(&lockedWriter{ - writer: wsConn, - }) +func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) { + protoConn := NewTTYProtocolWS(wsConn) - // Add the receiver to the list of receivers in the seesion, so we need to write-lock session.mainRWLock.Lock() - rcvHandleEl := session.ttyReceiverConnections.PushBack(rcvWriter) - lastWindowSizeData, _ := MarshalMsg(session.lastWindowSizeMsg) + rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn) + winSize := session.lastWindowSizeMsg session.mainRWLock.Unlock() - log.Debugf("New WS connection (%s). Serving ..", wsConn.Address()) + log.Debugf("New WS connection (%s). Serving ..", wsConn.RemoteAddr().String()) // Sending the initial size of the window, if we have one - rcvWriter.WriteRawData(lastWindowSizeData) + protoConn.SetWinSize(winSize.Cols, winSize.Rows) // Wait until the TTYReceiver will close the connection on its end for { - msg, err := rcvReader.ReadMessage() + err := protoConn.ReadAndHandle( + func(data []byte) { + session.ttyWriter.Write(data) + }, + func(cols, rows int) { + // Maybe ask the server side to refresh/repaint the tty window? + }, + ) + if err != nil { log.Debugf("Finished the WS reading loop: %s", err.Error()) break } - - // We only support MsgTTYWrite from the web terminal for now - if msg.Type != MsgIDWrite { - log.Warnf("Unknown message over the WS connection: type %s", msg.Type) - break - } - - var msgW MsgTTYWrite - json.Unmarshal(msg.Data, &msgW) - session.ttyWriter.Write(msgW.Data) } // Remove the recevier from the list of the receiver of this session, so we need to write-lock session.mainRWLock.Lock() - session.ttyReceiverConnections.Remove(rcvHandleEl) + session.ttyProtoConnections.Remove(rcvHandleEl) session.mainRWLock.Unlock() wsConn.Close() diff --git a/server/tty_protocol_rw.go b/server/tty_protocol_rw.go index 00af6f2..1ec4a0a 100644 --- a/server/tty_protocol_rw.go +++ b/server/tty_protocol_rw.go @@ -2,19 +2,10 @@ package server import ( "encoding/json" - "errors" - "fmt" "io" -) - -type TTYProtocolReader struct { - reader io.Reader - jsonDecoder *json.Decoder -} -type TTYProtocolWriter struct { - writer io.Writer -} + "github.com/gorilla/websocket" +) const ( MsgIDWrite = "Write" @@ -22,7 +13,7 @@ const ( ) // Message used to encapsulate the rest of the bessages bellow -type MsgAll struct { +type MsgWrapper struct { Type string Data []byte } @@ -37,26 +28,21 @@ type MsgTTYWinSize struct { Rows int } -func ReadAndUnmarshalMsg(reader io.Reader, aMessage interface{}) (err error) { - var wrapperMsg MsgAll - // Wait here for the right message to come - dec := json.NewDecoder(reader) - err = dec.Decode(&wrapperMsg) - - if err != nil { - return errors.New("Cannot decode top message: " + err.Error()) - } +type OnMsgWrite func(data []byte) +type OnMsgWinSize func(cols, rows int) - err = json.Unmarshal(wrapperMsg.Data, aMessage) +type TTYProtocolWS struct { + ws *websocket.Conn +} - if err != nil { - return errors.New("Cannot decode message: " + err.Error() + string(wrapperMsg.Data)) +func NewTTYProtocolWS(ws *websocket.Conn) *TTYProtocolWS { + return &TTYProtocolWS{ + ws: ws, } - return } -func MarshalMsg(aMessage interface{}) (_ []byte, err error) { - var msg MsgAll +func marshalMsg(aMessage interface{}) (_ []byte, err error) { + var msg MsgWrapper if writeMsg, ok := aMessage.(MsgTTYWrite); ok { msg.Type = MsgIDWrite @@ -80,64 +66,62 @@ func MarshalMsg(aMessage interface{}) (_ []byte, err error) { return nil, nil } -func MarshalAndWriteMsg(writer io.Writer, aMessage interface{}) (err error) { - b, err := MarshalMsg(aMessage) +func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) { + var msg MsgWrapper + + _, r, err := handler.ws.NextReader() if err != nil { - return + // underlaying conn is closed. signal that through io.EOF + return io.EOF } - n, err := writer.Write(b) - - if n != len(b) { - err = fmt.Errorf("Unable to write : wrote %d out of %d bytes", n, len(b)) - return - } + err = json.NewDecoder(r).Decode(&msg) if err != nil { return } - return -} - -func NewTTYProtocolWriter(w io.Writer) *TTYProtocolWriter { - return &TTYProtocolWriter{ - writer: w, - } -} - -func NewTTYProtocolReader(r io.Reader) *TTYProtocolReader { - return &TTYProtocolReader{ - reader: r, - jsonDecoder: json.NewDecoder(r), + switch msg.Type { + case MsgIDWrite: + var msgWrite MsgTTYWrite + err = json.Unmarshal(msg.Data, &msgWrite) + if err == nil { + onWrite(msgWrite.Data) + } + case MsgIDWinSize: + var msgRemoteWinSize MsgTTYWinSize + err = json.Unmarshal(msg.Data, &msgRemoteWinSize) + if err == nil { + onWinSize(msgRemoteWinSize.Cols, msgRemoteWinSize.Rows) + } } -} - -func (reader *TTYProtocolReader) ReadMessage() (msg MsgAll, err error) { - // TODO: perhaps read here the error, and transform it to something that's understandable - // from the outside in the context of this object - err = reader.jsonDecoder.Decode(&msg) return } -func (writer *TTYProtocolWriter) SetWinSize(cols, rows int) error { +func (handler *TTYProtocolWS) SetWinSize(cols, rows int) error { msgWinChanged := MsgTTYWinSize{ Cols: cols, Rows: rows, } - return MarshalAndWriteMsg(writer.writer, msgWinChanged) + data, err := marshalMsg(msgWinChanged) + if err != nil { + return err + } + + return handler.ws.WriteMessage(websocket.TextMessage, data) } // Function to send data from one the sender to the server and the other way around. -func (writer *TTYProtocolWriter) Write(buff []byte) (int, error) { +func (handler *TTYProtocolWS) Write(buff []byte) (int, error) { msgWrite := MsgTTYWrite{ Data: buff, Size: len(buff), } - return len(buff), MarshalAndWriteMsg(writer.writer, msgWrite) -} + data, err := marshalMsg(msgWrite) + if err != nil { + return 0, err + } -func (writer *TTYProtocolWriter) WriteRawData(buff []byte) (int, error) { - return writer.writer.Write(buff) + return len(buff), handler.ws.WriteMessage(websocket.TextMessage, data) } diff --git a/server/websockets_connection.go b/server/websockets_connection.go deleted file mode 100644 index 5e74605..0000000 --- a/server/websockets_connection.go +++ /dev/null @@ -1,43 +0,0 @@ -package server - -import ( - "github.com/gorilla/websocket" -) - -type WSConnection struct { - connection *websocket.Conn - address string -} - -func newWSConnection(conn *websocket.Conn) *WSConnection { - return &WSConnection{ - connection: conn, - address: conn.RemoteAddr().String(), - } -} - -func (handle *WSConnection) Write(data []byte) (n int, err error) { - w, err := handle.connection.NextWriter(websocket.TextMessage) - if err != nil { - return 0, err - } - n, err = w.Write(data) - w.Close() - return -} - -func (handle *WSConnection) Close() (err error) { - return handle.connection.Close() -} - -func (handle *WSConnection) Address() string { - return handle.address -} - -func (handle *WSConnection) Read(data []byte) (int, error) { - _, r, err := handle.connection.NextReader() - if err != nil { - return 0, err - } - return r.Read(data) -}