diff --git a/client.go b/client.go index ddeea86..c21ad76 100644 --- a/client.go +++ b/client.go @@ -125,7 +125,7 @@ func (c *ttyShareClient) Run() (err error) { defer term.RestoreTerminal(os.Stdin.Fd(), state) clearScreen() - protoWS := server.NewTTYProtocolWS(c.wsConn) + protoWS := server.NewTTYProtocolWSLocked(c.wsConn) monitorWinChanges := func() { // start monitoring the size of the terminal diff --git a/server/session.go b/server/session.go index 2820962..4b0b75c 100644 --- a/server/session.go +++ b/server/session.go @@ -2,7 +2,6 @@ package server import ( "container/list" - "io" "sync" "github.com/gorilla/websocket" @@ -17,18 +16,6 @@ type ttyShareSession struct { ptyHandler PTYHandler } -// quick and dirty locked writer -type lockedWriter struct { - writer io.Writer - lock sync.Mutex -} - -func (wl *lockedWriter) Write(data []byte) (int, error) { - wl.lock.Lock() - defer wl.lock.Unlock() - return wl.writer.Write(data) -} - func copyList(l *list.List) *list.List { newList := list.New() for e := l.Front(); e != nil; e = e.Next() { @@ -52,7 +39,7 @@ func (session *ttyShareSession) WindowSize(cols, rows int) error { session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows} session.mainRWLock.Unlock() - session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { + session.forEachReceiverLock(func(rcvConn *TTYProtocolWSLocked) bool { rcvConn.SetWinSize(cols, rows) return true }) @@ -60,7 +47,7 @@ func (session *ttyShareSession) WindowSize(cols, rows int) error { } func (session *ttyShareSession) Write(data []byte) (int, error) { - session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { + session.forEachReceiverLock(func(rcvConn *TTYProtocolWSLocked) bool { rcvConn.Write(data) return true }) @@ -71,14 +58,14 @@ func (session *ttyShareSession) Write(data []byte) (int, error) { // this function was called. Note that there might be receivers which might have lost // the connection since this function was called. // Return false in the callback to not continue for the rest of the receivers -func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWS) bool) { +func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWSLocked) bool) { session.mainRWLock.RLock() // TODO: Maybe find a better way? rcvsCopy := copyList(session.ttyProtoConnections) session.mainRWLock.RUnlock() for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() { - receiver := receiverE.Value.(*TTYProtocolWS) + receiver := receiverE.Value.(*TTYProtocolWSLocked) if !cb(receiver) { break } @@ -88,7 +75,7 @@ func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocol // Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine) // When HandleWSConnection will exit, the connection to the TTYReceiver will be closed func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) { - protoConn := NewTTYProtocolWS(wsConn) + protoConn := NewTTYProtocolWSLocked(wsConn) session.mainRWLock.Lock() rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn) diff --git a/server/tty_protocol_rw.go b/server/tty_protocol_rw.go index 1ec4a0a..0033c56 100644 --- a/server/tty_protocol_rw.go +++ b/server/tty_protocol_rw.go @@ -3,6 +3,7 @@ package server import ( "encoding/json" "io" + "sync" "github.com/gorilla/websocket" ) @@ -31,12 +32,13 @@ type MsgTTYWinSize struct { type OnMsgWrite func(data []byte) type OnMsgWinSize func(cols, rows int) -type TTYProtocolWS struct { +type TTYProtocolWSLocked struct { ws *websocket.Conn + lock sync.Mutex } -func NewTTYProtocolWS(ws *websocket.Conn) *TTYProtocolWS { - return &TTYProtocolWS{ +func NewTTYProtocolWSLocked(ws *websocket.Conn) *TTYProtocolWSLocked { + return &TTYProtocolWSLocked{ ws: ws, } } @@ -67,7 +69,7 @@ func marshalMsg(aMessage interface{}) (_ []byte, err error) { } -func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) { +func (handler *TTYProtocolWSLocked) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) { var msg MsgWrapper _, r, err := handler.ws.NextReader() @@ -99,21 +101,24 @@ func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgW return } -func (handler *TTYProtocolWS) SetWinSize(cols, rows int) error { +func (handler *TTYProtocolWSLocked) SetWinSize(cols, rows int) (err error) { msgWinChanged := MsgTTYWinSize{ Cols: cols, Rows: rows, } data, err := marshalMsg(msgWinChanged) if err != nil { - return err + return } - return handler.ws.WriteMessage(websocket.TextMessage, data) + handler.lock.Lock() + err = handler.ws.WriteMessage(websocket.TextMessage, data) + handler.lock.Unlock() + return } // Function to send data from one the sender to the server and the other way around. -func (handler *TTYProtocolWS) Write(buff []byte) (int, error) { +func (handler *TTYProtocolWSLocked) Write(buff []byte) (n int, err error) { msgWrite := MsgTTYWrite{ Data: buff, Size: len(buff), @@ -123,5 +128,8 @@ func (handler *TTYProtocolWS) Write(buff []byte) (int, error) { return 0, err } - return len(buff), handler.ws.WriteMessage(websocket.TextMessage, data) + handler.lock.Lock() + n, err = len(buff), handler.ws.WriteMessage(websocket.TextMessage, data) + handler.lock.Unlock() + return }