From 8eec179329d736b0968cb1c5f9d067aab76e85da Mon Sep 17 00:00:00 2001 From: Michael J Feher Date: Fri, 21 Apr 2023 18:00:40 -0500 Subject: [PATCH] Add CrossOrigin to the service --- main.go | 6 ++++-- server/server.go | 6 ++++-- vendor/github.com/gorilla/websocket/server.go | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index a2584fd..c8715a4 100644 --- a/main.go +++ b/main.go @@ -16,13 +16,14 @@ import ( var version string = "0.0.0" -func createServer(frontListenAddress string, frontendPath string, pty server.PTYHandler, sessionID string, allowTunneling bool) *server.TTYServer { +func createServer(frontListenAddress string, frontendPath string, pty server.PTYHandler, sessionID string, allowTunneling bool, crossOrigin bool) *server.TTYServer { config := ttyServer.TTYServerConfig{ FrontListenAddress: frontListenAddress, FrontendPath: frontendPath, PTY: pty, SessionID: sessionID, AllowTunneling: allowTunneling, + CrossOrigin: crossOrigin, } server := ttyServer.NewTTYServer(config) @@ -84,6 +85,7 @@ Flags: detachKeys := flag.String("detach-keys", "ctrl-o,ctrl-c", "[c] Sequence of keys to press for closing the connection. Supported: https://godoc.org/github.com/moby/term#pkg-variables.") allowTunneling := flag.Bool("A", false, "[s] Allow clients to create a TCP tunnel") tunnelConfig := flag.String("L", "", "[c] TCP tunneling addresses: local_port:remote_host:remote_port. The client will listen on local_port for TCP connections, and will forward those to the from the server side to remote_host:remote_port") + crossOrgin := flag.Bool("cross-origin", false, "[s] Allow cross origin requests to the server") verbose := flag.Bool("verbose", false, "Verbose logging") flag.Usage = func() { @@ -195,7 +197,7 @@ Flags: pty = &nilPTY{} } - server := createServer(*listenAddress, *frontendPath, pty, sessionID, *allowTunneling) + server := createServer(*listenAddress, *frontendPath, pty, sessionID, *allowTunneling, *crossOrgin) if cols, rows, e := ptyMaster.GetWinSize(); e == nil { server.WindowSize(cols, rows) } diff --git a/server/server.go b/server/server.go index ffbb129..cdf1788 100644 --- a/server/server.go +++ b/server/server.go @@ -41,6 +41,7 @@ type TTYServerConfig struct { PTY PTYHandler SessionID string AllowTunneling bool + CrossOrigin bool } // TTYServer represents the instance of a tty server @@ -126,7 +127,7 @@ func NewTTYServer(config TTYServerConfig) (server *TTYServer) { server.handleWithTemplateHtml(w, r, "tty-share.in.html", templateModel) }) routesHandler.HandleFunc(ttyWsPath, func(w http.ResponseWriter, r *http.Request) { - server.handleTTYWebsocket(w, r) + server.handleTTYWebsocket(w, r, config.CrossOrigin) }) if server.config.AllowTunneling { // tunnel websockets connection @@ -151,7 +152,7 @@ func NewTTYServer(config TTYServerConfig) (server *TTYServer) { return server } -func (server *TTYServer) handleTTYWebsocket(w http.ResponseWriter, r *http.Request) { +func (server *TTYServer) handleTTYWebsocket(w http.ResponseWriter, r *http.Request, crossOrigin bool) { if r.Method != "GET" { w.WriteHeader(http.StatusForbidden) return @@ -160,6 +161,7 @@ func (server *TTYServer) handleTTYWebsocket(w http.ResponseWriter, r *http.Reque upgrader := websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, + CrossOrigin: crossOrigin, } conn, err := upgrader.Upgrade(w, r, nil) diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go index 24d53b3..932ef3b 100644 --- a/vendor/github.com/gorilla/websocket/server.go +++ b/vendor/github.com/gorilla/websocket/server.go @@ -72,6 +72,8 @@ type Upgrader struct { // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool + + CrossOrigin bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -149,7 +151,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if checkOrigin == nil { checkOrigin = checkSameOrigin } - if !checkOrigin(r) { + if !checkOrigin(r) && !u.CrossOrigin { return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") }