Refactor the messages marshaling and unmarshaling

pull/25/head
Vasile Popescu 4 years ago committed by Elis Popescu
parent 3c77e059b3
commit 366edcd23e

@ -1,7 +1,6 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
@ -12,7 +11,7 @@ import (
"sync/atomic"
"syscall"
ttyServer "github.com/elisescu/tty-share/server"
"github.com/elisescu/tty-share/server"
"github.com/gorilla/websocket"
"github.com/moby/term"
log "github.com/sirupsen/logrus"
@ -20,7 +19,7 @@ import (
type ttyShareClient struct {
url string
connection *websocket.Conn
wsConn *websocket.Conn
detachKeys string
wcChan chan os.Signal
writeFlag uint32 // used with atomic
@ -36,7 +35,7 @@ type ttyShareClient struct {
func newTtyShareClient(url string, detachKeys string) *ttyShareClient {
return &ttyShareClient{
url: url,
connection: nil,
wsConn: nil,
detachKeys: detachKeys,
wcChan: make(chan os.Signal, 1),
writeFlag: 1,
@ -47,15 +46,6 @@ func clearScreen() {
fmt.Fprintf(os.Stdout, "\033[H\033[2J")
}
type wsTextWriter struct {
conn *websocket.Conn
}
func (w *wsTextWriter) Write(data []byte) (n int, err error) {
err = w.conn.WriteMessage(websocket.TextMessage, data)
return len(data), err
}
type keyListener struct {
wrappedReader io.Reader
}
@ -120,7 +110,7 @@ func (c *ttyShareClient) Run() (err error) {
log.Debugf("Built the WS URL from the headers: %s", wsURL)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
c.wsConn, _, err = websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return
}
@ -133,15 +123,12 @@ func (c *ttyShareClient) Run() (err error) {
state, err := term.MakeRaw(os.Stdin.Fd())
defer term.RestoreTerminal(os.Stdin.Fd(), state)
c.connection = conn
clearScreen()
// start monitoring the size of the terminal
signal.Notify(c.wcChan, syscall.SIGWINCH)
defer signal.Stop(c.wcChan)
monitorWinChanges := func() {
// start monitoring the size of the terminal
signal.Notify(c.wcChan, syscall.SIGWINCH)
for {
select {
case <-c.wcChan:
@ -152,58 +139,48 @@ func (c *ttyShareClient) Run() (err error) {
}
}
readLoop := func() {
for {
var msg ttyServer.MsgAll
_, r, err := conn.NextReader()
if err != nil {
log.Debugf("Connection closed")
return
}
err = json.NewDecoder(r).Decode(&msg)
if err != nil {
log.Errorf("Cannot read JSON: %s", err.Error())
}
protoWS := server.NewTTYProtocolWS(c.wsConn)
switch msg.Type {
case ttyServer.MsgIDWrite:
var msgWrite ttyServer.MsgTTYWrite
err := json.Unmarshal(msg.Data, &msgWrite)
readLoop := func() {
if err != nil {
log.Errorf("Cannot read JSON: %s", err.Error())
}
var err error
for {
err = protoWS.ReadAndHandle(
// onWrite
func(data []byte) {
if atomic.LoadUint32(&c.writeFlag) != 0 {
os.Stdout.Write(data)
}
},
// onWindowSize
func(cols, rows int) {
c.winSizesMutex.Lock()
c.winSizes.remoteW = uint16(cols)
c.winSizes.remoteH = uint16(rows)
c.winSizesMutex.Unlock()
c.updateThisWinSize()
c.updateAndDecideStdoutMuted()
},
)
if atomic.LoadUint32(&c.writeFlag) != 0 {
os.Stdout.Write(msgWrite.Data)
}
case ttyServer.MsgIDWinSize:
var msgRemoteWinSize ttyServer.MsgTTYWinSize
err := json.Unmarshal(msg.Data, &msgRemoteWinSize)
if err != nil {
continue
if err != nil {
log.Errorf("Error parsing remote message: %s", err.Error())
if err == io.EOF {
// Remote WS connection closed
return
}
c.winSizesMutex.Lock()
c.winSizes.remoteW = uint16(msgRemoteWinSize.Cols)
c.winSizes.remoteH = uint16(msgRemoteWinSize.Rows)
c.winSizesMutex.Unlock()
c.updateThisWinSize()
c.updateAndDecideStdoutMuted()
}
}
}
writeLoop := func() {
ww := &wsTextWriter{
conn: conn,
}
kl := &keyListener{
wrappedReader: term.NewEscapeProxy(os.Stdin, detachBytes),
}
_, err := io.Copy(ttyServer.NewTTYProtocolWriter(ww), kl)
_, err := io.Copy(protoWS, kl)
if err != nil {
log.Debugf("Connection closed")
log.Debugf("Connection closed: %s", err.Error())
c.Stop()
return
}
@ -218,5 +195,6 @@ func (c *ttyShareClient) Run() (err error) {
}
func (c *ttyShareClient) Stop() {
c.connection.Close()
c.wsConn.Close()
signal.Stop(c.wcChan)
}

@ -147,7 +147,7 @@ func (server *TTYServer) handleWebsocket(w http.ResponseWriter, r *http.Request)
}
server.newClientCB(conn.RemoteAddr().String())
server.session.HandleWSConnection(newWSConnection(conn))
server.session.HandleWSConnection(conn)
}
func panicIfErr(err error) {

@ -2,30 +2,31 @@ package server
import (
"container/list"
"encoding/json"
"io"
"sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
type ttyShareSession struct {
mainRWLock sync.RWMutex
ttyReceiverConnections *list.List
isAlive bool
lastWindowSizeMsg MsgTTYWinSize
ttyWriter io.Writer
mainRWLock sync.RWMutex
ttyProtoConnections *list.List
isAlive bool
lastWindowSizeMsg MsgTTYWinSize
ttyWriter io.Writer
}
func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession {
ttyShareSession := &ttyShareSession{
ttyReceiverConnections: list.New(),
ttyWriter: ttyWriter,
}
// quick and dirty locked writer
type lockedWriter struct {
writer io.Writer
lock sync.Mutex
}
return ttyShareSession
func (wl *lockedWriter) Write(data []byte) (int, error) {
wl.lock.Lock()
defer wl.lock.Unlock()
return wl.writer.Write(data)
}
func copyList(l *list.List) *list.List {
@ -36,123 +37,89 @@ func copyList(l *list.List) *list.List {
return newList
}
func (session *ttyShareSession) WindowSize(cols, rows int) (err error) {
msg := MsgTTYWinSize{
Cols: cols,
Rows: rows,
func newTTYShareSession(ttyWriter io.Writer) *ttyShareSession {
ttyShareSession := &ttyShareSession{
ttyProtoConnections: list.New(),
ttyWriter: ttyWriter,
}
return ttyShareSession
}
func (session *ttyShareSession) WindowSize(cols, rows int) error {
session.mainRWLock.Lock()
session.lastWindowSizeMsg = msg
session.lastWindowSizeMsg = MsgTTYWinSize{Cols: cols, Rows: rows}
session.mainRWLock.Unlock()
data, _ := MarshalMsg(msg)
session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool {
_, e := rcvConn.WriteRawData(data)
if e != nil {
err = e
}
session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool {
rcvConn.SetWinSize(cols, rows)
return true
})
return
return nil
}
func (session *ttyShareSession) Write(buff []byte) (written int, err error) {
msg := MsgTTYWrite{
Data: buff,
Size: len(buff),
}
data, _ := MarshalMsg(msg)
session.forEachReceiverLock(func(rcvConn *TTYProtocolWriter) bool {
_, e := rcvConn.WriteRawData(data)
if e != nil {
err = e
}
func (session *ttyShareSession) Write(data []byte) (int, error) {
session.forEachReceiverLock(func(rcvConn *TTYProtocolWS) bool {
rcvConn.Write(data)
return true
})
// TODO: fix this
written = len(buff)
return
return len(data), nil
}
// Runs the callback cb for each of the receivers in the list of the receivers, as it was when
// this function was called. Note that there might be receivers which might have lost
// the connection since this function was called.
// Return false in the callback to not continue for the rest of the receivers
func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWriter) bool) {
func (session *ttyShareSession) forEachReceiverLock(cb func(rcvConn *TTYProtocolWS) bool) {
session.mainRWLock.RLock()
// TODO: Maybe find a better way?
rcvsCopy := copyList(session.ttyReceiverConnections)
rcvsCopy := copyList(session.ttyProtoConnections)
session.mainRWLock.RUnlock()
for receiverE := rcvsCopy.Front(); receiverE != nil; receiverE = receiverE.Next() {
receiver := receiverE.Value.(*TTYProtocolWriter)
receiver := receiverE.Value.(*TTYProtocolWS)
if !cb(receiver) {
break
}
}
}
// quick and dirty locked writer
type lockedWriter struct {
writer io.Writer
lock sync.Mutex
}
func (wl *lockedWriter) Write(data []byte) (int, error) {
wl.lock.Lock()
defer wl.lock.Unlock()
return wl.writer.Write(data)
}
// Will run on the TTYReceiver connection go routine (e.g.: on the websockets connection routine)
// When HandleWSConnection will exit, the connection to the TTYReceiver will be closed
func (session *ttyShareSession) HandleWSConnection(wsConn *WSConnection) {
rcvReader := NewTTYProtocolReader(wsConn)
// Gorilla websockets don't allow for concurent writes. Lazy, and perhaps shorter solution
// is to wrap a lock around a writer. Maybe later replace it with a channel
rcvWriter := NewTTYProtocolWriter(&lockedWriter{
writer: wsConn,
})
func (session *ttyShareSession) HandleWSConnection(wsConn *websocket.Conn) {
protoConn := NewTTYProtocolWS(wsConn)
// Add the receiver to the list of receivers in the seesion, so we need to write-lock
session.mainRWLock.Lock()
rcvHandleEl := session.ttyReceiverConnections.PushBack(rcvWriter)
lastWindowSizeData, _ := MarshalMsg(session.lastWindowSizeMsg)
rcvHandleEl := session.ttyProtoConnections.PushBack(protoConn)
winSize := session.lastWindowSizeMsg
session.mainRWLock.Unlock()
log.Debugf("New WS connection (%s). Serving ..", wsConn.Address())
log.Debugf("New WS connection (%s). Serving ..", wsConn.RemoteAddr().String())
// Sending the initial size of the window, if we have one
rcvWriter.WriteRawData(lastWindowSizeData)
protoConn.SetWinSize(winSize.Cols, winSize.Rows)
// Wait until the TTYReceiver will close the connection on its end
for {
msg, err := rcvReader.ReadMessage()
err := protoConn.ReadAndHandle(
func(data []byte) {
session.ttyWriter.Write(data)
},
func(cols, rows int) {
// Maybe ask the server side to refresh/repaint the tty window?
},
)
if err != nil {
log.Debugf("Finished the WS reading loop: %s", err.Error())
break
}
// We only support MsgTTYWrite from the web terminal for now
if msg.Type != MsgIDWrite {
log.Warnf("Unknown message over the WS connection: type %s", msg.Type)
break
}
var msgW MsgTTYWrite
json.Unmarshal(msg.Data, &msgW)
session.ttyWriter.Write(msgW.Data)
}
// Remove the recevier from the list of the receiver of this session, so we need to write-lock
session.mainRWLock.Lock()
session.ttyReceiverConnections.Remove(rcvHandleEl)
session.ttyProtoConnections.Remove(rcvHandleEl)
session.mainRWLock.Unlock()
wsConn.Close()

@ -2,19 +2,10 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"io"
)
type TTYProtocolReader struct {
reader io.Reader
jsonDecoder *json.Decoder
}
type TTYProtocolWriter struct {
writer io.Writer
}
"github.com/gorilla/websocket"
)
const (
MsgIDWrite = "Write"
@ -22,7 +13,7 @@ const (
)
// Message used to encapsulate the rest of the bessages bellow
type MsgAll struct {
type MsgWrapper struct {
Type string
Data []byte
}
@ -37,26 +28,21 @@ type MsgTTYWinSize struct {
Rows int
}
func ReadAndUnmarshalMsg(reader io.Reader, aMessage interface{}) (err error) {
var wrapperMsg MsgAll
// Wait here for the right message to come
dec := json.NewDecoder(reader)
err = dec.Decode(&wrapperMsg)
if err != nil {
return errors.New("Cannot decode top message: " + err.Error())
}
type OnMsgWrite func(data []byte)
type OnMsgWinSize func(cols, rows int)
err = json.Unmarshal(wrapperMsg.Data, aMessage)
type TTYProtocolWS struct {
ws *websocket.Conn
}
if err != nil {
return errors.New("Cannot decode message: " + err.Error() + string(wrapperMsg.Data))
func NewTTYProtocolWS(ws *websocket.Conn) *TTYProtocolWS {
return &TTYProtocolWS{
ws: ws,
}
return
}
func MarshalMsg(aMessage interface{}) (_ []byte, err error) {
var msg MsgAll
func marshalMsg(aMessage interface{}) (_ []byte, err error) {
var msg MsgWrapper
if writeMsg, ok := aMessage.(MsgTTYWrite); ok {
msg.Type = MsgIDWrite
@ -80,64 +66,62 @@ func MarshalMsg(aMessage interface{}) (_ []byte, err error) {
return nil, nil
}
func MarshalAndWriteMsg(writer io.Writer, aMessage interface{}) (err error) {
b, err := MarshalMsg(aMessage)
func (handler *TTYProtocolWS) ReadAndHandle(onWrite OnMsgWrite, onWinSize OnMsgWinSize) (err error) {
var msg MsgWrapper
_, r, err := handler.ws.NextReader()
if err != nil {
return
// underlaying conn is closed. signal that through io.EOF
return io.EOF
}
n, err := writer.Write(b)
if n != len(b) {
err = fmt.Errorf("Unable to write : wrote %d out of %d bytes", n, len(b))
return
}
err = json.NewDecoder(r).Decode(&msg)
if err != nil {
return
}
return
}
func NewTTYProtocolWriter(w io.Writer) *TTYProtocolWriter {
return &TTYProtocolWriter{
writer: w,
}
}
func NewTTYProtocolReader(r io.Reader) *TTYProtocolReader {
return &TTYProtocolReader{
reader: r,
jsonDecoder: json.NewDecoder(r),
switch msg.Type {
case MsgIDWrite:
var msgWrite MsgTTYWrite
err = json.Unmarshal(msg.Data, &msgWrite)
if err == nil {
onWrite(msgWrite.Data)
}
case MsgIDWinSize:
var msgRemoteWinSize MsgTTYWinSize
err = json.Unmarshal(msg.Data, &msgRemoteWinSize)
if err == nil {
onWinSize(msgRemoteWinSize.Cols, msgRemoteWinSize.Rows)
}
}
}
func (reader *TTYProtocolReader) ReadMessage() (msg MsgAll, err error) {
// TODO: perhaps read here the error, and transform it to something that's understandable
// from the outside in the context of this object
err = reader.jsonDecoder.Decode(&msg)
return
}
func (writer *TTYProtocolWriter) SetWinSize(cols, rows int) error {
func (handler *TTYProtocolWS) SetWinSize(cols, rows int) error {
msgWinChanged := MsgTTYWinSize{
Cols: cols,
Rows: rows,
}
return MarshalAndWriteMsg(writer.writer, msgWinChanged)
data, err := marshalMsg(msgWinChanged)
if err != nil {
return err
}
return handler.ws.WriteMessage(websocket.TextMessage, data)
}
// Function to send data from one the sender to the server and the other way around.
func (writer *TTYProtocolWriter) Write(buff []byte) (int, error) {
func (handler *TTYProtocolWS) Write(buff []byte) (int, error) {
msgWrite := MsgTTYWrite{
Data: buff,
Size: len(buff),
}
return len(buff), MarshalAndWriteMsg(writer.writer, msgWrite)
}
data, err := marshalMsg(msgWrite)
if err != nil {
return 0, err
}
func (writer *TTYProtocolWriter) WriteRawData(buff []byte) (int, error) {
return writer.writer.Write(buff)
return len(buff), handler.ws.WriteMessage(websocket.TextMessage, data)
}

@ -1,43 +0,0 @@
package server
import (
"github.com/gorilla/websocket"
)
type WSConnection struct {
connection *websocket.Conn
address string
}
func newWSConnection(conn *websocket.Conn) *WSConnection {
return &WSConnection{
connection: conn,
address: conn.RemoteAddr().String(),
}
}
func (handle *WSConnection) Write(data []byte) (n int, err error) {
w, err := handle.connection.NextWriter(websocket.TextMessage)
if err != nil {
return 0, err
}
n, err = w.Write(data)
w.Close()
return
}
func (handle *WSConnection) Close() (err error) {
return handle.connection.Close()
}
func (handle *WSConnection) Address() string {
return handle.address
}
func (handle *WSConnection) Read(data []byte) (int, error) {
_, r, err := handle.connection.NextReader()
if err != nil {
return 0, err
}
return r.Read(data)
}
Loading…
Cancel
Save