Use a mutex to avoid concurrent writes to the WS connections

Concurrent writes to the same WS connection are not allowed by the
gorrila/websocket package.
pull/25/head
Vasile Popescu 4 years ago committed by Elis Popescu
parent e1f4cdd06d
commit 33c6017254

@ -125,7 +125,7 @@ func (c *ttyShareClient) Run() (err error) {
defer term.RestoreTerminal(os.Stdin.Fd(), state) defer term.RestoreTerminal(os.Stdin.Fd(), state)
clearScreen() clearScreen()
protoWS := server.NewTTYProtocolWS(c.wsConn) protoWS := server.NewTTYProtocolWSLocked(c.wsConn)
monitorWinChanges := func() { monitorWinChanges := func() {
// start monitoring the size of the terminal // start monitoring the size of the terminal

@ -2,7 +2,6 @@ package server
import ( import (
"container/list" "container/list"
"io"
"sync" "sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -17,18 +16,6 @@ type ttyShareSession struct {
ptyHandler PTYHandler ptyHandler PTYHandler
} }
// quick and dirty locked writer
type lockedWriter struct {
writer io.Writer
lock sync.Mutex
}
func (wl *lockedWriter) Write(data []byte) (int, error) {
wl.lock.Lock()
defer wl.lock.Unlock()
return wl.writer.Write(data)
}
func copyList(l *list.List) *list.List { func copyList(l *list.List) *list.List {
newList := list.New() newList := list.New()
for e := l.Front(); e != nil; e = e.Next() { for e := l.Front(); e != nil; e = e.Next() {
@ -52,7 +39,7 @@ func (session *ttyShareSession) WindowSize(cols, rows int) error {
session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows} session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows}
session.mainRWLock.Unlock() session.mainRWLock.Unlock()
session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { session.forEachReceiverLock(func(rcvConn *TTYProtocolWSLocked) bool {
rcvConn.SetWinSize(cols, rows) rcvConn.SetWinSize(cols, rows)
return true return true
}) })
@ -60,7 +47,7 @@ func (session *ttyShareSession) WindowSize(cols, rows int) error {
} }
func (session *ttyShareSession) Write(data []byte) (int, error) { func (session *ttyShareSession) Write(data []byte) (int, error) {
session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool { session.forEachReceiverLock(func(rcvConn *TTYProtocolWSLocked) bool {
rcvConn.Write(data) rcvConn.Write(data)
return true return true
}) })
@ -71,14 +58,14 @@ func (session *ttyShareSession) Write(data []byte) (int, error) {
// this function was called. Note that there might be receivers which might have lost // this function was called. Note that there might be receivers which might have lost
// the connection since this function was called. // the connection since this function was called.
// Return false in the callback to not continue for the rest of the receivers // Return false in the callback to not continue for the rest of the receivers
func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWS) bool) { func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWSLocked) bool) {
session.mainRWLock.RLock() session.mainRWLock.RLock()
// TODO: Maybe find a better way? // TODO: Maybe find a better way?
rcvsCopy := copyList(session.ttyProtoConnections) rcvsCopy := copyList(session.ttyProtoConnections)
session.mainRWLock.RUnlock() session.mainRWLock.RUnlock()
for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() { for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() {
receiver := receiverE.Value.(*TTYProtocolWS) receiver := receiverE.Value.(*TTYProtocolWSLocked)
if !cb(receiver) { if !cb(receiver) {
break break
} }
@ -88,7 +75,7 @@ func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocol
// Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine) // Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine)
// When HandleWSConnection will exit, the connection to the TTYReceiver will be closed // When HandleWSConnection will exit, the connection to the TTYReceiver will be closed
func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) { func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) {
protoConn := NewTTYProtocolWS(wsConn) protoConn := NewTTYProtocolWSLocked(wsConn)
session.mainRWLock.Lock() session.mainRWLock.Lock()
rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn) rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn)

@ -3,6 +3,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -31,12 +32,13 @@ type MsgTTYWinSize struct {
type OnMsgWrite func(data []byte) type OnMsgWrite func(data []byte)
type OnMsgWinSize func(cols, rows int) type OnMsgWinSize func(cols, rows int)
type TTYProtocolWS struct { type TTYProtocolWSLocked struct {
ws *websocket.Conn ws *websocket.Conn
lock sync.Mutex
} }
func NewTTYProtocolWS(ws *websocket.Conn) *TTYProtocolWS { func NewTTYProtocolWSLocked(ws *websocket.Conn) *TTYProtocolWSLocked {
return &TTYProtocolWS{ return &TTYProtocolWSLocked{
ws: ws, ws: ws,
} }
} }
@ -67,7 +69,7 @@ func marshalMsg(aMessage interface{}) (_ []byte, err error) {
} }
func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) { func (handler *TTYProtocolWSLocked) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) {
var msg MsgWrapper var msg MsgWrapper
_, r, err := handler.ws.NextReader() _, r, err := handler.ws.NextReader()
@ -99,21 +101,24 @@ func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgW
return return
} }
func (handler *TTYProtocolWS) SetWinSize(cols, rows int) error { func (handler *TTYProtocolWSLocked) SetWinSize(cols, rows int) (err error) {
msgWinChanged := MsgTTYWinSize{ msgWinChanged := MsgTTYWinSize{
Cols: cols, Cols: cols,
Rows: rows, Rows: rows,
} }
data, err := marshalMsg(msgWinChanged) data, err := marshalMsg(msgWinChanged)
if err != nil { if err != nil {
return err return
} }
return handler.ws.WriteMessage(websocket.TextMessage, data) handler.lock.Lock()
err = handler.ws.WriteMessage(websocket.TextMessage, data)
handler.lock.Unlock()
return
} }
// Function to send data from one the sender to the server and the other way around. // Function to send data from one the sender to the server and the other way around.
func (handler *TTYProtocolWS) Write(buff []byte) (int, error) { func (handler *TTYProtocolWSLocked) Write(buff []byte) (n int, err error) {
msgWrite := MsgTTYWrite{ msgWrite := MsgTTYWrite{
Data: buff, Data: buff,
Size: len(buff), Size: len(buff),
@ -123,5 +128,8 @@ func (handler *TTYProtocolWS) Write(buff []byte) (int, error) {
return 0, err return 0, err
} }
return len(buff), handler.ws.WriteMessage(websocket.TextMessage, data) handler.lock.Lock()
n, err = len(buff), handler.ws.WriteMessage(websocket.TextMessage, data)
handler.lock.Unlock()
return
} }

Loading…
Cancel
Save