Add CrossOrigin to the service (#68)

Adds Cross Origin as an optional flag to allow users to turn off the CORs check.
master
Michael J Feher 1 year ago committed by GitHub
parent 3c3993ca33
commit 407bfbbbbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,13 +16,14 @@ import (
var version string = "0.0.0" var version string = "0.0.0"
func createServer(frontListenAddress string, frontendPath string, pty server.PTYHandler, sessionID string, allowTunneling bool) *server.TTYServer { func createServer(frontListenAddress string, frontendPath string, pty server.PTYHandler, sessionID string, allowTunneling bool, crossOrigin bool) *server.TTYServer {
config := ttyServer.TTYServerConfig{ config := ttyServer.TTYServerConfig{
FrontListenAddress: frontListenAddress, FrontListenAddress: frontListenAddress,
FrontendPath: frontendPath, FrontendPath: frontendPath,
PTY: pty, PTY: pty,
SessionID: sessionID, SessionID: sessionID,
AllowTunneling: allowTunneling, AllowTunneling: allowTunneling,
CrossOrigin: crossOrigin,
} }
server := ttyServer.NewTTYServer(config) server := ttyServer.NewTTYServer(config)
@ -84,6 +85,7 @@ Flags:
detachKeys := flag.String("detach-keys", "ctrl-o,ctrl-c", "[c] Sequence of keys to press for closing the connection. Supported: https://godoc.org/github.com/moby/term#pkg-variables.") detachKeys := flag.String("detach-keys", "ctrl-o,ctrl-c", "[c] Sequence of keys to press for closing the connection. Supported: https://godoc.org/github.com/moby/term#pkg-variables.")
allowTunneling := flag.Bool("A", false, "[s] Allow clients to create a TCP tunnel") allowTunneling := flag.Bool("A", false, "[s] Allow clients to create a TCP tunnel")
tunnelConfig := flag.String("L", "", "[c] TCP tunneling addresses: local_port:remote_host:remote_port. The client will listen on local_port for TCP connections, and will forward those to the from the server side to remote_host:remote_port") tunnelConfig := flag.String("L", "", "[c] TCP tunneling addresses: local_port:remote_host:remote_port. The client will listen on local_port for TCP connections, and will forward those to the from the server side to remote_host:remote_port")
crossOrgin := flag.Bool("cross-origin", false, "[s] Allow cross origin requests to the server")
verbose := flag.Bool("verbose", false, "Verbose logging") verbose := flag.Bool("verbose", false, "Verbose logging")
flag.Usage = func() { flag.Usage = func() {
@ -195,7 +197,7 @@ Flags:
pty = &nilPTY{} pty = &nilPTY{}
} }
server := createServer(*listenAddress, *frontendPath, pty, sessionID, *allowTunneling) server := createServer(*listenAddress, *frontendPath, pty, sessionID, *allowTunneling, *crossOrgin)
if cols, rows, e := ptyMaster.GetWinSize(); e == nil { if cols, rows, e := ptyMaster.GetWinSize(); e == nil {
server.WindowSize(cols, rows) server.WindowSize(cols, rows)
} }

@ -41,6 +41,7 @@ type TTYServerConfig struct {
PTY PTYHandler PTY PTYHandler
SessionID string SessionID string
AllowTunneling bool AllowTunneling bool
CrossOrigin bool
} }
// TTYServer represents the instance of a tty server // TTYServer represents the instance of a tty server
@ -126,7 +127,7 @@ func NewTTYServer(config TTYServerConfig) (server *TTYServer) {
server.handleWithTemplateHtml(w, r, "tty-share.in.html", templateModel) server.handleWithTemplateHtml(w, r, "tty-share.in.html", templateModel)
}) })
routesHandler.HandleFunc(ttyWsPath, func(w http.ResponseWriter, r *http.Request) { routesHandler.HandleFunc(ttyWsPath, func(w http.ResponseWriter, r *http.Request) {
server.handleTTYWebsocket(w, r) server.handleTTYWebsocket(w, r, config.CrossOrigin)
}) })
if server.config.AllowTunneling { if server.config.AllowTunneling {
// tunnel websockets connection // tunnel websockets connection
@ -151,16 +152,25 @@ func NewTTYServer(config TTYServerConfig) (server *TTYServer) {
return server return server
} }
func (server *TTYServer) handleTTYWebsocket(w http.ResponseWriter, r *http.Request) { func (server *TTYServer) handleTTYWebsocket(w http.ResponseWriter, r *http.Request, crossOrigin bool) {
if r.Method != "GET" { if r.Method != "GET" {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
return return
} }
upgrader := websocket.Upgrader{ upgrader := websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
} }
if crossOrigin {
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

Loading…
Cancel
Save