diff --git a/.github/workflows/commit_lint.yml b/.github/workflows/commit_lint.yml new file mode 100644 index 0000000..2df66ba --- /dev/null +++ b/.github/workflows/commit_lint.yml @@ -0,0 +1,38 @@ +name: Lint Commit Messages + +on: + workflow_dispatch: + pull_request: + +permissions: + contents: read + +jobs: + commitlint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup node + uses: actions/setup-node@v4 + with: + node-version: lts/* + + - name: Install commitlint + run: npm install -D @commitlint/cli @commitlint/config-conventional + + - name: Print versions + run: | + git --version + node --version + npm --version + npx commitlint --version + + - name: Create default conventional commitlint config + run: | + echo "module.exports = {extends: ['@commitlint/config-conventional']};" > commitlint.config.js + + - name: Validate PR commits with commitlint + run: npx commitlint --from ${{ github.event.pull_request.base.sha }} --to ${{ github.event.pull_request.head.sha }} --verbose diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 11f6005..aec6f5c 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -14,6 +14,7 @@ builds: goarch: - amd64 - arm64 + - 386 # Add support for Windows 32-bit (x86) archives: - formats: [ 'tar.gz' ] @@ -36,6 +37,7 @@ changelog: exclude: - "^docs:" - "^test:" + - "^chore:" release: github: diff --git a/README.md b/README.md index ccacfe3..af21345 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -> WARNING: This is still in Alpha stage, not ready for production use yet. -

mmar @@ -8,21 +6,33 @@ # mmar -mmar (pronounced "ma-mar") is a zero-dependancy, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL. +mmar (pronounced "ma-mar") is a zero-dependency, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL. It allows you to quickly share what you are working on locally with others without the hassle of a full deployment, especially if it is not ready to be shipped. - +### Demo + +

+ + mmar + +

### Key Features - Super simple to use -- Utilize "mmar.dev" to tunnel for free on a generated subdomain +- Provides "mmar.dev" to tunnel for free on a generated subdomain - Expose multiple ports on different subdomains - Live logs of requests coming into your localhost server -- Zero dependancies +- Zero dependencies - Self-host your own mmar server to have full control +### Limitations + +- Currently only supports the HTTP protocol; other protocols such as WebSockets have not been tested and will likely not work +- Requests through mmar are limited to 10 MB in size; this could be made configurable in the future +- There is a limit of 5 mmar tunnels per IP to prevent abuse; this could also be made configurable in the future + ### Learn More The development, implementation and technical details of mmar has all been documented in a [devlog series](https://ymusleh.com/tags/mmar.html). You can read more about it there. @@ -31,7 +41,15 @@ _p.s. mmar means “corridor” or “pass-through” in Arabic._ ## Installation -### MacOS +### Linux/MacOS + +Install mmar + +```sh +curl -sSL https://raw.githubusercontent.com/yusuf-musleh/mmar/refs/heads/master/install.sh | sh +``` + +### MacOS (Homebrew) Use [Homebrew](https://brew.sh/) to install `mmar` on MacOS: @@ -50,13 +68,9 @@ brew upgrade yusuf-musleh/mmar-tap/mmar The fastest way to create a tunnel what is running on your `localhost:8080` using [Docker](https://www.docker.com/) is by running this command: ``` -docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.2.2 client --local-port 8080 +docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.2.6 client --local-port 8080 ``` -### Linux - -See Docker or Manual installation instructions - ### Windows See Docker or Manual installation instructions @@ -111,6 +125,19 @@ Commands: Run `mmar -h` to get help for a specific command ``` +### Configuring through Environment Variables + +You can define the various mmar command flags in environment variables rather than
passing them in with the command. Here are the available environment variables along with the corresponding flags: + +``` +MMAR__SERVER_HTTP_PORT -> mmar server --http-port +MMAR__SERVER_TCP_PORT -> mmar server --tcp-port +MMAR__LOCAL_PORT -> mmar client --local-port +MMAR__TUNNEL_HTTP_PORT -> mmar client --tunnel-http-port +MMAR__TUNNEL_TCP_PORT -> mmar client --tunnel-tcp-port +MMAR__TUNNEL_HOST -> mmar client --tunnel-host +``` + ## Self-Host Since everything is open-source, you can easily self-host mmar on your own infrastructure under your own domain. @@ -133,7 +160,7 @@ To deploy mmar on your own VPS using docker, you can do the following: ```yaml services: mmar-server: - image: "ghcr.io/yusuf-musleh/mmar:v0.2.2" # <----- make sure to use the mmar's latest version + image: "ghcr.io/yusuf-musleh/mmar:v0.2.3" # <----- make sure to use the mmar's latest version restart: unless-stopped command: server environment: @@ -144,6 +171,20 @@ To deploy mmar on your own VPS using docker, you can do the following: - "6673:6673" ``` + The `USERNAME_HASH` and `PASSWORD_HASH` env variables are the hashes of the credentials needed to access the stats page, which can be viewed at `stats.yourdomain.com`. The stats page returns JSON with basic information about the number of connected clients (i.e. open tunnels), along with a list of their subdomains and when they were created: + + ```json + { + "connectedClients": [ + { + "createdOn": "2025-03-01T08:01:46Z", + "id": "owrwf0" + } + ], + "connectedClientsCount": 1 + } + ``` + 1. Next, we need to also add a reverse proxy, such as [Nginx](https://nginx.org/) or [Caddy](https://caddyserver.com/), so that requests and TCP connections to your domain are routed accordingly. Since the mmar client communicates with the server using TCP, you need to make sure that the reverse proxy supports routing on TCP, and not just HTTP.
I highly recommend [Caddy](https://caddyserver.com/) as it also handles obtaining SSL certificates for your wildcard subdomains automatically for you, in addition to having a Layer4 reverse proxy to route TCP connections. To get this functionality we need to include a few additional Caddy modules, the [layer4 module](github.com/mholt/caddy-l4) as well as the [caddy-dns](https://github.com/caddy-dns) module that matches your domain registrar, in my case I am using the [namecheap module](https://github.com/caddy-dns/namecheap) in order to automatically issue SSL certificates for wildcard subdomains. @@ -215,7 +256,7 @@ To deploy mmar on your own VPS using docker, you can do the following: - caddy_data:/data - caddy_config:/config mmar-server: - image: "ghcr.io/yusuf-musleh/mmar:v0.2.2" # <----- make sure to use the mmar's latest version + image: "ghcr.io/yusuf-musleh/mmar:v0.2.3" # <----- make sure to use the mmar's latest version restart: unless-stopped command: server environment: diff --git a/cmd/mmar/main.go b/cmd/mmar/main.go index 87b9778..edac2bb 100644 --- a/cmd/mmar/main.go +++ b/cmd/mmar/main.go @@ -14,24 +14,46 @@ import ( func main() { serverCmd := flag.NewFlagSet(constants.SERVER_CMD, flag.ExitOnError) serverHttpPort := serverCmd.String( - "http-port", constants.SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT_HELP, + "http-port", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT), + constants.SERVER_HTTP_PORT_HELP, ) serverTcpPort := serverCmd.String( - "tcp-port", constants.SERVER_TCP_PORT, constants.SERVER_TCP_PORT_HELP, + "tcp-port", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_TCP_PORT, constants.SERVER_TCP_PORT), + constants.SERVER_TCP_PORT_HELP, ) clientCmd := flag.NewFlagSet(constants.CLIENT_CMD, flag.ExitOnError) clientLocalPort := clientCmd.String( - "local-port", constants.CLIENT_LOCAL_PORT, constants.CLIENT_LOCAL_PORT_HELP, + "local-port", + 
utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_LOCAL_PORT, constants.CLIENT_LOCAL_PORT), + constants.CLIENT_LOCAL_PORT_HELP, ) clientTunnelHttpPort := clientCmd.String( - "tunnel-http-port", constants.TUNNEL_HTTP_PORT, constants.CLIENT_HTTP_PORT_HELP, + "tunnel-http-port", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HTTP_PORT, constants.TUNNEL_HTTP_PORT), + constants.CLIENT_HTTP_PORT_HELP, ) clientTunnelTcpPort := clientCmd.String( - "tunnel-tcp-port", constants.SERVER_TCP_PORT, constants.CLIENT_TCP_PORT_HELP, + "tunnel-tcp-port", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_TCP_PORT, constants.SERVER_TCP_PORT), + constants.CLIENT_TCP_PORT_HELP, ) clientTunnelHost := clientCmd.String( - "tunnel-host", constants.TUNNEL_HOST, constants.TUNNEL_HOST_HELP, + "tunnel-host", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HOST, constants.TUNNEL_HOST), + constants.TUNNEL_HOST_HELP, + ) + clientCustomDns := clientCmd.String( + "custom-dns", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_DNS, ""), + constants.CLIENT_CUSTOM_DNS_HELP, + ) + clientCustomCert := clientCmd.String( + "custom-cert", + utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_CERT, ""), + constants.CLIENT_CUSTOM_CERT_HELP, ) versionCmd := flag.NewFlagSet(constants.VERSION_CMD, flag.ExitOnError) @@ -59,6 +81,8 @@ func main() { TunnelHttpPort: *clientTunnelHttpPort, TunnelTcpPort: *clientTunnelTcpPort, TunnelHost: *clientTunnelHost, + CustomDns: *clientCustomDns, + CustomCert: *clientCustomCert, } client.Run(mmarClientConfig) case constants.VERSION_CMD: diff --git a/constants/main.go b/constants/main.go index 8f96cea..51ea7c5 100644 --- a/constants/main.go +++ b/constants/main.go @@ -1,7 +1,7 @@ package constants const ( - MMAR_VERSION = "0.2.2" + MMAR_VERSION = "0.2.6" VERSION_CMD = "version" SERVER_CMD = "server" @@ -12,29 +12,44 @@ const ( TUNNEL_HOST = "mmar.dev" TUNNEL_HTTP_PORT = "443" + MMAR_ENV_VAR_SERVER_HTTP_PORT = "MMAR__SERVER_HTTP_PORT" + 
MMAR_ENV_VAR_SERVER_TCP_PORT = "MMAR__SERVER_TCP_PORT" + MMAR_ENV_VAR_LOCAL_PORT = "MMAR__LOCAL_PORT" + MMAR_ENV_VAR_TUNNEL_HTTP_PORT = "MMAR__TUNNEL_HTTP_PORT" + MMAR_ENV_VAR_TUNNEL_TCP_PORT = "MMAR__TUNNEL_TCP_PORT" + MMAR_ENV_VAR_TUNNEL_HOST = "MMAR__TUNNEL_HOST" + MMAR_ENV_VAR_CUSTOM_DNS = "MMAR__CUSTOM_DNS" + MMAR_ENV_VAR_CUSTOM_CERT = "MMAR__CUSTOM_CERT" + SERVER_STATS_DEFAULT_USERNAME = "admin" SERVER_STATS_DEFAULT_PASSWORD = "admin" SERVER_HTTP_PORT_HELP = "Define port where mmar will bind to and run on server for HTTP requests." SERVER_TCP_PORT_HELP = "Define port where mmar will bind to and run on server for TCP connections." - CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar." - CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel." - CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel." - TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to." + CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar." + CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel." + CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel." + TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to." + CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)" + CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. 
(eg: /path/to/cert)" - TUNNEL_MESSAGE_PROTOCOL_VERSION = 1 + TUNNEL_MESSAGE_PROTOCOL_VERSION = 4 TUNNEL_MESSAGE_DATA_DELIMITER = '\n' ID_CHARSET = "abcdefghijklmnopqrstuvwxyz0123456789" ID_LENGTH = 6 - MAX_TUNNELS_PER_IP = 5 - TUNNEL_RECONNECT_TIMEOUT = 3 - GRACEFUL_SHUTDOWN_TIMEOUT = 3 - TUNNEL_CREATE_TIMEOUT = 3 - REQ_BODY_READ_CHUNK_TIMEOUT = 3 - DEST_REQUEST_TIMEOUT = 30 - MAX_REQ_BODY_SIZE = 10000000 // 10mb + MAX_TUNNELS_PER_IP = 5 + TUNNEL_RECONNECT_TIMEOUT = 3 + GRACEFUL_SHUTDOWN_TIMEOUT = 3 + TUNNEL_CREATE_TIMEOUT = 3 + REQ_BODY_READ_CHUNK_TIMEOUT = 3 + DEST_REQUEST_TIMEOUT = 30 + HEARTBEAT_FROM_SERVER_TIMEOUT = 5 + HEARTBEAT_FROM_CLIENT_TIMEOUT = 2 + READ_DEADLINE = 3 + MAX_REQ_BODY_SIZE = 10000000 // 10mb + REQUEST_ID_BUFF_SIZE = 4 CLIENT_DISCONNECT_ERR_TEXT = "Tunnel is closed, cannot connect to mmar client." LOCALHOST_NOT_RUNNING_ERR_TEXT = "Tunneled successfully, but nothing is running on localhost." diff --git a/docs/assets/img/mmar-demo.gif b/docs/assets/img/mmar-demo.gif new file mode 100644 index 0000000..636a18b Binary files /dev/null and b/docs/assets/img/mmar-demo.gif differ diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..20efdd8 --- /dev/null +++ b/install.sh @@ -0,0 +1,52 @@ +#!/bin/sh + +set -e + +REPO="yusuf-musleh/mmar" +BINARY="mmar" + +echo "Installing $BINARY..." + +# Detect OS +OS="$(uname -s)" +case "$OS" in + Linux) OS_TITLE="Linux";; + Darwin) OS_TITLE="Darwin";; + *) echo "Unsupported OS: $OS"; exit 1;; +esac + +# Detect ARCH +ARCH="$(uname -m)" +case "$ARCH" in + x86_64) ARCH_ID="x86_64";; + i386) ARCH_ID="i386";; + aarch64|arm64) ARCH_ID="arm64";; + *) echo "Unsupported architecture: $ARCH"; exit 1;; +esac + +ASSET="${BINARY}_${OS_TITLE}_${ARCH_ID}.tar.gz" +URL="https://github.com/$REPO/releases/latest/download/$ASSET" + +# Temp dir +TMP_DIR=$(mktemp -d) +cd "$TMP_DIR" + +echo "Downloading $ASSET..." +curl -sSL "$URL" -o "$ASSET" + +echo "Extracting..." 
+tar -xzf "$ASSET" + +# Install location +INSTALL_DIR="/usr/local/bin" + +# Ensure /usr/local/lib exists +if [ ! -d "$INSTALL_DIR" ]; then + sudo mkdir -p "$INSTALL_DIR" +fi + +echo "Installing to $INSTALL_DIR" +sudo install -m 755 "$BINARY" "$INSTALL_DIR/$BINARY" + +echo "$BINARY installed successfully to $INSTALL_DIR/$BINARY" +"$BINARY" version || true diff --git a/internal/client/main.go b/internal/client/main.go index f240ef1..2eb8a92 100644 --- a/internal/client/main.go +++ b/internal/client/main.go @@ -4,6 +4,8 @@ import ( "bufio" "bytes" "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -26,6 +28,8 @@ type ConfigOptions struct { TunnelHttpPort string TunnelTcpPort string TunnelHost string + CustomDns string + CustomCert string } type MmarClient struct { @@ -58,9 +62,64 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { }, } + // Use custom DNS if set + if mc.CustomDns != "" { + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial("udp", mc.CustomDns) + }, + } + dialer := &net.Dialer{ + Resolver: r, + } + + tp := &http.Transport{ + DialContext: dialer.DialContext, + } + + fwdClient.Transport = tp + } + + // Use custom TLS certificate if setup + if mc.CustomCert != "" { + certData, certFileErr := os.ReadFile(mc.CustomCert) + if certFileErr != nil { + logger.Log( + constants.RED, + fmt.Sprintf( + "Could not read certificate from file: %v", + certFileErr, + )) + os.Exit(1) + } + + cert, certErr := x509.ParseCertificate(certData) + if certErr != nil { + logger.Log(constants.YELLOW, "Warning: Could not load custom certificate") + } else { + fwdClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + RootCAs: x509.NewCertPool(), + } + fwdClient.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert) + } + } + reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData)) - req, reqErr := http.ReadRequest(reqReader) + 
// Extract RequestId + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + _, err := io.ReadFull(reqReader, reqIdBuff) + if err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse RequestId for request: %v\n", err)) + return + } + + // Include RequestId in tunnel back message + msgData := []byte{} + msgData = append(msgData, reqIdBuff...) + + req, reqErr := http.ReadRequest(reqReader) if reqErr != nil { if errors.Is(reqErr, io.EOF) { logger.Log(constants.DEFAULT_COLOR, "Connection to mmar server closed or disconnected. Exiting...") @@ -80,32 +139,36 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) { resp, fwdErr := fwdClient.Do(req) if fwdErr != nil { if errors.Is(fwdErr, syscall.ECONNREFUSED) || errors.Is(fwdErr, io.ErrUnexpectedEOF) || errors.Is(fwdErr, io.EOF) { - localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING} + localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING, MsgData: msgData} if err := mc.SendMessage(localhostNotRunningMsg); err != nil { log.Fatal(err) } return } else if errors.Is(fwdErr, context.DeadlineExceeded) { - destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT} + destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT, MsgData: msgData} if err := mc.SendMessage(destServerTimedoutMsg); err != nil { log.Fatal(err) } return } - log.Fatalf("Failed to forward: %v", fwdErr) + invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST, MsgData: msgData} + if err := mc.SendMessage(invalidRespFromDestMsg); err != nil { + log.Fatal(err) + } + return } - logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true) - // Writing response to buffer to tunnel it back var responseBuff bytes.Buffer resp.Write(&responseBuff) - - respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: 
responseBuff.Bytes()} + msgData = append(msgData, responseBuff.Bytes()...) + respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: msgData} if err := mc.SendMessage(respMessage); err != nil { log.Fatal(err) } + + logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true) } // Keep attempting to reconnect the existing tunnel until successful @@ -118,7 +181,7 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) { logger.Log(constants.DEFAULT_COLOR, "Attempting to reconnect...") conn, err := net.DialTimeout( "tcp", - fmt.Sprintf("%s:%s", mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort), + net.JoinHostPort(mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort), constants.TUNNEL_CREATE_TIMEOUT*time.Second, ) if err != nil { @@ -126,6 +189,15 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) { continue } mc.Tunnel.Conn = conn + mc.Tunnel.Reader = bufio.NewReader(conn) + + // Try to reclaim the same subdomain + reclaimTunnelMsg := protocol.TunnelMessage{MsgType: protocol.RECLAIM_TUNNEL, MsgData: []byte(mc.subdomain)} + if err := mc.SendMessage(reclaimTunnelMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...") + os.Exit(0) + } + break } } @@ -136,13 +208,35 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) { case <-ctx.Done(): // Client gracefully shutdown return default: + // Send heartbeat if nothing has been read for a while + receiveMessageTimeout := time.AfterFunc( + constants.HEARTBEAT_FROM_CLIENT_TIMEOUT*time.Second, + func() { + heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_CLIENT} + if err := mc.SendMessage(heartbeatMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Failed to send heartbeat. 
Exiting...") + os.Exit(0) + } + // Set a read timeout, if no response to heartbeat is recieved within that period, + // attempt to reconnect to the server + readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second)) + mc.Tunnel.Conn.SetReadDeadline(readDeadline) + }, + ) + tunnelMsg, err := mc.ReceiveMessage() + // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout + // as we do not need to send heartbeat or check connection health in this iteration + receiveMessageTimeout.Stop() + mc.Tunnel.Conn.SetReadDeadline(time.Time{}) + if err != nil { // If the context was cancelled just return if errors.Is(ctx.Err(), context.Canceled) { return - } else if errors.Is(err, os.ErrDeadlineExceeded) { - continue + } else if errors.Is(err, protocol.INVALID_MESSAGE_PROTOCOL_VERSION) { + logger.Log(constants.YELLOW, "The mmar message protocol has been updated, please update mmar.") + os.Exit(0) } logger.Log(constants.DEFAULT_COLOR, "Tunnel connection disconnected.") @@ -154,21 +248,9 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) { } switch tunnelMsg.MsgType { - case protocol.CLIENT_CONNECT: + case protocol.TUNNEL_CREATED, protocol.TUNNEL_RECLAIMED: tunnelSubdomain := string(tunnelMsg.MsgData) - // If there is an existing subdomain, that means we are reconnecting with an - // existing mmar client, try to reclaim the same subdomain - if mc.subdomain != "" { - reconnectMsg := protocol.TunnelMessage{MsgType: protocol.CLIENT_RECLAIM_SUBDOMAIN, MsgData: []byte(tunnelSubdomain + ":" + mc.subdomain)} - mc.subdomain = "" - if err := mc.SendMessage(reconnectMsg); err != nil { - logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. 
Exiting...") - os.Exit(0) - } - continue - } else { - mc.subdomain = tunnelSubdomain - } + mc.subdomain = tunnelSubdomain logger.LogTunnelCreated(tunnelSubdomain, mc.TunnelHost, mc.TunnelHttpPort, mc.LocalPort) case protocol.CLIENT_TUNNEL_LIMIT: limit := logger.ColorLogStr( @@ -184,6 +266,15 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) { os.Exit(0) case protocol.REQUEST: go mc.handleRequestMessage(tunnelMsg) + case protocol.HEARTBEAT_ACK: + // Got a heartbeat ack, that means the connection is healthy, + // we do not need to perform any action + case protocol.HEARTBEAT_FROM_SERVER: + heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK} + if err := mc.SendMessage(heartbeatAckMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Failed to send Heartbeat Ack. Exiting...") + os.Exit(0) + } } } } @@ -198,7 +289,7 @@ func Run(config ConfigOptions) { conn, err := net.DialTimeout( "tcp", - fmt.Sprintf("%s:%s", config.TunnelHost, config.TunnelTcpPort), + net.JoinHostPort(config.TunnelHost, config.TunnelTcpPort), constants.TUNNEL_CREATE_TIMEOUT*time.Second, ) if err != nil { @@ -215,7 +306,7 @@ func Run(config ConfigOptions) { } defer conn.Close() mmarClient := MmarClient{ - protocol.Tunnel{Conn: conn}, + protocol.Tunnel{Conn: conn, Reader: bufio.NewReader(conn)}, config, "", } @@ -226,6 +317,12 @@ func Run(config ConfigOptions) { // Process Tunnel Messages coming from mmar server go mmarClient.ProcessTunnelMessages(ctx) + createTunnelMsg := protocol.TunnelMessage{MsgType: protocol.CREATE_TUNNEL} + if err := mmarClient.SendMessage(createTunnelMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, "Failed to create Tunnel. 
Exiting...") + os.Exit(0) + } + // Wait for an interrupt signal, if received, terminate gracefully <-sigInt diff --git a/internal/protocol/main.go b/internal/protocol/main.go index c0b3241..1287734 100644 --- a/internal/protocol/main.go +++ b/internal/protocol/main.go @@ -19,12 +19,18 @@ import ( const ( REQUEST = uint8(iota + 1) RESPONSE - CLIENT_CONNECT - CLIENT_RECLAIM_SUBDOMAIN + CREATE_TUNNEL + RECLAIM_TUNNEL + TUNNEL_CREATED + TUNNEL_RECLAIMED CLIENT_DISCONNECT CLIENT_TUNNEL_LIMIT LOCALHOST_NOT_RUNNING DEST_REQUEST_TIMEDOUT + HEARTBEAT_FROM_CLIENT + HEARTBEAT_FROM_SERVER + HEARTBEAT_ACK + INVALID_RESP_FROM_DEST ) var INVALID_MESSAGE_PROTOCOL_VERSION = errors.New("Invalid Message Protocol Version") @@ -33,7 +39,7 @@ var INVALID_MESSAGE_TYPE = errors.New("Invalid Tunnel Message Type") func isValidTunnelMessageType(mt uint8) (uint8, error) { // Iterate through all the message type, from first to last, checking // if the provided message type matches one of them - for msgType := REQUEST; msgType <= DEST_REQUEST_TIMEDOUT; msgType++ { + for msgType := REQUEST; msgType <= INVALID_RESP_FROM_DEST; msgType++ { if mt == msgType { return msgType, nil } @@ -45,9 +51,10 @@ func isValidTunnelMessageType(mt uint8) (uint8, error) { func TunnelErrState(errState uint8) string { // TODO: Have nicer/more elaborative error messages/pages errStates := map[uint8]string{ - CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT, - LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT, - DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT, + CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT, + LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT, + DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT, + INVALID_RESP_FROM_DEST: constants.READ_RESP_BODY_ERR_TEXT, } fallbackErr := "An error occured while attempting to tunnel." 
@@ -71,6 +78,7 @@ type Tunnel struct { Id string Conn net.Conn CreatedOn time.Time + Reader *bufio.Reader } type TunnelInterface interface { @@ -173,6 +181,10 @@ func (tm *TunnelMessage) deserializeMessage(reader *bufio.Reader) error { return nil } +func (t *Tunnel) ReservedSubdomain() bool { + return t.Id != "" +} + func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error { // Serialize tunnel message data serializedMsg, serializeErr := tunnelMsg.serializeMessage() @@ -184,11 +196,9 @@ func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error { } func (t *Tunnel) ReceiveMessage() (TunnelMessage, error) { - msgReader := bufio.NewReader(t.Conn) - // Read and deserialize tunnel message data tunnelMessage := TunnelMessage{} - deserializeErr := tunnelMessage.deserializeMessage(msgReader) + deserializeErr := tunnelMessage.deserializeMessage(t.Reader) return tunnelMessage, deserializeErr } diff --git a/internal/server/main.go b/internal/server/main.go index fee090f..594ba88 100644 --- a/internal/server/main.go +++ b/internal/server/main.go @@ -4,19 +4,18 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "encoding/json" "errors" "fmt" "html" "io" "log" - "math/rand" "net" "net/http" "os" "os/signal" "slices" - "strings" "sync" "time" @@ -44,7 +43,6 @@ type IncomingRequest struct { responseWriter http.ResponseWriter request *http.Request cancel context.CancelCauseFunc - serializedReq []byte ctx context.Context } @@ -53,11 +51,14 @@ type OutgoingResponse struct { body []byte } +type RequestId uint32 + // Tunnel to Client type ClientTunnel struct { protocol.Tunnel - incomingChannel chan IncomingRequest - outgoingChannel chan protocol.TunnelMessage + incomingChannel chan IncomingRequest + outgoingChannel chan protocol.TunnelMessage + inflightRequests *sync.Map } func (ct *ClientTunnel) drainChannels() { @@ -119,6 +120,16 @@ func (ct *ClientTunnel) close(graceful bool) { ) } +// Generate unique request id for incoming request for client +func (ct *ClientTunnel) 
GenerateUniqueRequestID() RequestId { + var generatedReqId RequestId + + for _, exists := ct.inflightRequests.Load(generatedReqId); exists || generatedReqId == 0; { + generatedReqId = RequestId(GenerateRandomUint32()) + } + return generatedReqId +} + // Serves simple stats for mmar server behind Basic Authentication func (ms *MmarServer) handleServerStats(w http.ResponseWriter, r *http.Request) { // Check Basic Authentication @@ -193,19 +204,33 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Create response channel to receive response for tunneled request respChannel := make(chan OutgoingResponse) - // Tunnel the request - clientTunnel.incomingChannel <- IncomingRequest{ + // Add request to client's inflight requests + reqId := clientTunnel.GenerateUniqueRequestID() + incomingReq := IncomingRequest{ responseChannel: respChannel, responseWriter: w, request: r, cancel: cancel, - serializedReq: serializedRequest, ctx: ctx, } + clientTunnel.inflightRequests.Store(reqId, incomingReq) + + // Construct Request message data + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + binary.LittleEndian.PutUint32(reqIdBuff, uint32(reqId)) + reqMsgData := append(reqIdBuff, serializedRequest...) 
+ + // Tunnel the request to mmar client + reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: reqMsgData} + if err := clientTunnel.SendMessage(reqMessage); err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err)) + cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR) + } select { case <-ctx.Done(): // Request is canceled or Tunnel is closed if context is canceled handleCancel(context.Cause(ctx), w) + clientTunnel.inflightRequests.Delete(reqId) return case resp, _ := <-respChannel: // Await response for tunneled request // Add header to close the connection @@ -220,20 +245,15 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (ms *MmarServer) GenerateUniqueId() string { - reservedIDs := []string{"", "admin", "stats"} +func (ms *MmarServer) GenerateUniqueSubdomain() string { + reservedSubdomains := []string{"", "admin", "stats"} - generatedId := "" - for _, exists := ms.clients[generatedId]; exists || slices.Contains(reservedIDs, generatedId); { - var randSeed *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) - b := make([]byte, constants.ID_LENGTH) - for i := range b { - b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))] - } - generatedId = string(b) + generatedSubdomain := "" + for _, exists := ms.clients[generatedSubdomain]; exists || slices.Contains(reservedSubdomains, generatedSubdomain); { + generatedSubdomain = GenerateRandomID() } - return generatedId + return generatedSubdomain } func (ms *MmarServer) TunnelLimitedIP(ip string) bool { @@ -247,31 +267,40 @@ func (ms *MmarServer) TunnelLimitedIP(ip string) bool { return len(tunnels) >= constants.MAX_TUNNELS_PER_IP } -func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { +func (ms *MmarServer) newClientTunnel(tunnel protocol.Tunnel, subdomain string) (*ClientTunnel, error) { // Acquire lock to create new client tunnel data ms.mu.Lock() - // Generate 
unique ID for client - uniqueId := ms.GenerateUniqueId() - tunnel := protocol.Tunnel{ - Id: uniqueId, - Conn: conn, - CreatedOn: time.Now(), + var uniqueSubdomain string + var msgType uint8 + if subdomain != "" { + uniqueSubdomain = subdomain + msgType = protocol.TUNNEL_RECLAIMED + } else { + // Generate unique subdomain for client if not passed in + uniqueSubdomain = ms.GenerateUniqueSubdomain() + msgType = protocol.TUNNEL_CREATED } + tunnel.Id = uniqueSubdomain + // Create channels to tunnel requests to and recieve responses from incomingChannel := make(chan IncomingRequest) outgoingChannel := make(chan protocol.TunnelMessage) + // Initialize inflight requests map for client tunnel + var inflightRequests sync.Map + // Create client tunnel clientTunnel := ClientTunnel{ tunnel, incomingChannel, outgoingChannel, + &inflightRequests, } // Check if IP reached max tunnel limit - clientIP := utils.ExtractIP(conn.RemoteAddr().String()) + clientIP := utils.ExtractIP(tunnel.Conn.RemoteAddr().String()) limitedIP := ms.TunnelLimitedIP(clientIP) // If so, send limit message to client and close client tunnel if limitedIP { @@ -286,18 +315,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { } // Add client tunnel to clients - ms.clients[uniqueId] = clientTunnel + ms.clients[uniqueSubdomain] = clientTunnel // Associate tunnel with client IP - ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueId) + ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueSubdomain) // Release lock once created ms.mu.Unlock() - // Send unique ID to client - connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(uniqueId)} + // Send unique subdomain to client + connMessage := protocol.TunnelMessage{MsgType: msgType, MsgData: []byte(uniqueSubdomain)} if err := clientTunnel.SendMessage(connMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", 
err)) + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique subdomain msg to client: %v", err)) return nil, err } @@ -306,32 +335,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) { func (ms *MmarServer) handleTcpConnection(conn net.Conn) { - clientTunnel, err := ms.newClientTunnel(conn) - - if err != nil { - if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { - // Close the connection when client max tunnels limit reached - conn.Close() - return - } - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err)) - return + tunnel := protocol.Tunnel{ + Conn: conn, + CreatedOn: time.Now(), + Reader: bufio.NewReader(conn), } - logger.Log( - constants.DEFAULT_COLOR, - fmt.Sprintf( - "[%s] Tunnel created: %s", - clientTunnel.Tunnel.Id, - conn.RemoteAddr().String(), - ), - ) - // Process Tunnel Messages coming from mmar client - go ms.processTunnelMessages(clientTunnel) + go ms.processTunnelMessages(tunnel) +} - // Start goroutine to process tunneled requests - go ms.processTunneledRequestsForClient(clientTunnel) +func (ms *MmarServer) closeTunnel(t *protocol.Tunnel) { + t.Conn.Close() } func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) { @@ -351,168 +366,216 @@ func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) { ct.close(true) } -func (ms *MmarServer) processTunneledRequestsForClient(ct *ClientTunnel) { - for { - // Read requests coming in tunnel channel - incomingReq, ok := <-ct.incomingChannel - if !ok { - // Channel closed, client disconencted, shutdown goroutine - return - } +func (ms *MmarServer) closeClientTunnelOrConn(ct *ClientTunnel, t protocol.Tunnel) { - // Forward the request to mmar client - reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: incomingReq.serializedReq} - if err := ct.SendMessage(reqMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err)) - 
incomingReq.cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR) - continue - } + // If client has not reserved subdomain, just close the tcp connection + if !ct.ReservedSubdomain() { + ms.closeTunnel(&t) + return + } - // Wait for response for this request to come back from outgoing channel - respTunnelMsg, ok := <-ct.outgoingChannel - if !ok { - // Channel closed, client disconencted, shutdown goroutine - return - } + ms.closeClientTunnel(ct) +} - // Read response for forwarded request - respReader := bufio.NewReader(bytes.NewReader(respTunnelMsg.MsgData)) - resp, respErr := http.ReadResponse(respReader, incomingReq.request) +func (ms *MmarServer) handleResponseMessages(ct *ClientTunnel, tunnelMsg protocol.TunnelMessage) { + respReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData)) - if respErr != nil { - if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) { - incomingReq.cancel(CLIENT_DISCONNECTED_ERR) - ms.closeClientTunnel(ct) - return - } - failedReq := fmt.Sprintf("%s - %s%s", incomingReq.request.Method, html.EscapeString(incomingReq.request.URL.Path), incomingReq.request.URL.RawQuery) - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq)) - incomingReq.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR) - continue - } + // Extract RequestId + reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE) + _, err := io.ReadFull(respReader, reqIdBuff) + if err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] - Failed to parse RequestId for response: %v\n", ct.Tunnel.Id, err)) + return + } - respBody, respBodyErr := io.ReadAll(resp.Body) - if respBodyErr != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr)) - incomingReq.cancel(READ_RESP_BODY_ERR) - continue - } + // Get Inflight Request and remove it from inflight requests + reqId := RequestId(binary.LittleEndian.Uint32(reqIdBuff)) + inflight, loaded 
:= ct.inflightRequests.LoadAndDelete(reqId) + if !loaded { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to identify inflight request: %v", ct.Tunnel.Id, reqId)) + return + } - // Set headers for response - for hKey, hVal := range resp.Header { - incomingReq.responseWriter.Header().Set(hKey, hVal[0]) - // Add remaining values for header if more than than one exists - for i := 1; i < len(hVal); i++ { - incomingReq.responseWriter.Header().Add(hKey, hVal[i]) - } + inflightRequest, ok := inflight.(IncomingRequest) + if !ok { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to parse inflight request: %v", ct.Tunnel.Id, reqId)) + return + } + + // Read response for forwarded request + resp, respErr := http.ReadResponse(respReader, inflightRequest.request) + + if respErr != nil { + if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) { + inflightRequest.cancel(CLIENT_DISCONNECTED_ERR) + ms.closeClientTunnel(ct) + return } + failedReq := fmt.Sprintf("%s - %s%s", inflightRequest.request.Method, html.EscapeString(inflightRequest.request.URL.Path), inflightRequest.request.URL.RawQuery) + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq)) + inflightRequest.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR) + return + } - // Close response body - resp.Body.Close() + respBody, respBodyErr := io.ReadAll(resp.Body) + if respBodyErr != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr)) + inflightRequest.cancel(READ_RESP_BODY_ERR) + return + } - select { - case <-incomingReq.ctx.Done(): - // Request is canceled, on to the next request - continue - case incomingReq.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}: - // Send response data back + defer resp.Body.Close() + + // Set headers for response + for hKey, hVal := range resp.Header { + 
inflightRequest.responseWriter.Header().Set(hKey, hVal[0]) + // Add remaining values for header if more than than one exists + for i := 1; i < len(hVal); i++ { + inflightRequest.responseWriter.Header().Add(hKey, hVal[i]) } } + + select { + case <-inflightRequest.ctx.Done(): + // Request is canceled, do nothing + return + case inflightRequest.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}: + // Send response data back + } } -func (ms *MmarServer) processTunnelMessages(ct *ClientTunnel) { +func (ms *MmarServer) processTunnelMessages(t protocol.Tunnel) { + var ct *ClientTunnel for { - tunnelMsg, err := ct.ReceiveMessage() + // Send heartbeat if nothing has been read for a while + receiveMessageTimeout := time.AfterFunc( + constants.HEARTBEAT_FROM_SERVER_TIMEOUT*time.Second, + func() { + heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_SERVER} + if err := t.SendMessage(heartbeatMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send heartbeat: %v", err)) + ms.closeClientTunnelOrConn(ct, t) + return + } + // Set a read timeout, if no response to heartbeat is received within that period, + // that means the client has disconnected + readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second)) + t.Conn.SetReadDeadline(readDeadline) + }, + ) + + tunnelMsg, err := t.ReceiveMessage() + // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout + // as we do not need to send heartbeat or check connection health in this iteration + receiveMessageTimeout.Stop() + t.Conn.SetReadDeadline(time.Time{}) + if err != nil { logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Receive Message from client tunnel errored: %v", err)) if utils.NetworkError(err) { // If error with connection, stop processing messages - ms.closeClientTunnel(ct) + ms.closeClientTunnelOrConn(ct, t) return } continue } switch tunnelMsg.MsgType { - case protocol.RESPONSE: - ct.outgoingChannel 
<- tunnelMsg - case protocol.LOCALHOST_NOT_RUNNING: - // Create a response for Tunnel connected but localhost not running - errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING) - resp := http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(errState)), - } + case protocol.CREATE_TUNNEL: + // mmar client requesting new tunnel + ct, err = ms.newClientTunnel(t, "") - // Writing response to buffer to tunnel it back - var responseBuff bytes.Buffer - resp.Write(&responseBuff) - notRunningMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} - ct.outgoingChannel <- notRunningMsg - case protocol.DEST_REQUEST_TIMEDOUT: - // Create a response for Tunnel connected but localhost took too long to respond - errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT) - resp := http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(errState)), + if err != nil { + if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { + // Close the connection when client max tunnels limit reached + t.Conn.Close() + return + } + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err)) + return } - // Writing response to buffer to tunnel it back - var responseBuff bytes.Buffer - resp.Write(&responseBuff) - destTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()} - ct.outgoingChannel <- destTimedoutMsg - case protocol.CLIENT_DISCONNECT: - ms.closeClientTunnel(ct) - return - case protocol.CLIENT_RECLAIM_SUBDOMAIN: - newAndExistingIDs := strings.Split(string(tunnelMsg.MsgData), ":") - newId := newAndExistingIDs[0] - existingId := newAndExistingIDs[1] + logger.Log( + constants.DEFAULT_COLOR, + fmt.Sprintf( + "[%s] Tunnel created: %s", + ct.Tunnel.Id, + t.Conn.RemoteAddr().String(), + ), + ) + case protocol.RECLAIM_TUNNEL: + // mmar client reclaiming a previously created 
tunnel + existingId := string(tunnelMsg.MsgData) // Check if the subdomain has already been taken _, ok := ms.clients[existingId] if ok { // if so, close the tunnel, so the user can create a new one - ms.closeClientTunnel(ct) + ms.closeClientTunnelOrConn(ct, t) return } - ct.Tunnel.Id = existingId - - // Add existing client tunnel to clients - ms.clients[existingId] = *ct - - // Remove newId tunnel from clients - delete(ms.clients, newId) - - // Update the tunnels for the IP - clientIP := utils.ExtractIP(ct.Conn.RemoteAddr().String()) - newIdIndex := slices.Index(ms.tunnelsPerIP[clientIP], newId) - if newIdIndex == -1 { - ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], existingId) - } else { - ms.tunnelsPerIP[clientIP][newIdIndex] = existingId - } - - connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(existingId)} - if err := ct.SendMessage(connMessage); err != nil { - logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err)) - ms.closeClientTunnel(ct) + ct, err = ms.newClientTunnel(t, existingId) + if err != nil { + if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) { + // Close the connection when client max tunnels limit reached + t.Conn.Close() + return + } + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to reclaim ClientTunnel: %v", err)) return } logger.Log( constants.DEFAULT_COLOR, fmt.Sprintf( - "[%s] Tunnel reclaimed: %s -> %s", - newId, - ct.Conn.RemoteAddr().String(), + "[%s] Tunnel reclaimed: %s", existingId, + ct.Conn.RemoteAddr().String(), ), ) + case protocol.RESPONSE: + go ms.handleResponseMessages(ct, tunnelMsg) + case protocol.LOCALHOST_NOT_RUNNING: + // Create a response for Tunnel connected but localhost not running + errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING) + responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) + notRunningMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: 
append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, notRunningMsg) + case protocol.DEST_REQUEST_TIMEDOUT: + // Create a response for Tunnel connected but localhost took too long to respond + errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT) + responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState) + destTimedoutMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, destTimedoutMsg) + case protocol.CLIENT_DISCONNECT: + ms.closeClientTunnelOrConn(ct, t) + return + case protocol.HEARTBEAT_FROM_CLIENT: + heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK} + if err := t.SendMessage(heartbeatAckMsg); err != nil { + logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to heartbeat ack to client: %v", err)) + ms.closeClientTunnelOrConn(ct, t) + return + } + case protocol.HEARTBEAT_ACK: + // Got a heartbeat ack, that means the connection is healthy, + // we do not need to perform any action + case protocol.INVALID_RESP_FROM_DEST: + // Create a response for receiving invalid response from destination server + errState := protocol.TunnelErrState(protocol.INVALID_RESP_FROM_DEST) + responseBuff := createSerializedServerResp("500 Internal Server Error", http.StatusInternalServerError, errState) + invalidRespFromDestMsg := protocol.TunnelMessage{ + MsgType: protocol.RESPONSE, + MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...), + } + go ms.handleResponseMessages(ct, invalidRespFromDestMsg) } } } diff --git a/internal/server/utils.go b/internal/server/utils.go index 2fd144b..41d9b68 100644 --- a/internal/server/utils.go +++ b/internal/server/utils.go @@ -3,9 +3,12 @@ package server import ( "bytes" "context" + cryptoRand "crypto/rand" + "encoding/binary" "errors" "fmt" "io" + mathRand "math/rand" "net/http" "strconv" "time" @@ -21,7 +24,7 @@ var 
MAX_REQ_BODY_SIZE_ERR error = errors.New(constants.MAX_REQ_BODY_SIZE_ERR_TEX var FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR_TEXT) var FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR_TEXT) -func responseWith(respText string, w http.ResponseWriter, statusCode int) { +func respondWith(respText string, w http.ResponseWriter, statusCode int) { w.Header().Set("Content-Length", strconv.Itoa(len(respText))) w.Header().Set("Connection", "close") w.WriteHeader(statusCode) @@ -34,15 +37,15 @@ func handleCancel(cause error, w http.ResponseWriter) { // Cancelled, do nothing return case READ_BODY_CHUNK_TIMEOUT_ERR: - responseWith(cause.Error(), w, http.StatusRequestTimeout) + respondWith(cause.Error(), w, http.StatusRequestTimeout) case READ_BODY_CHUNK_ERR, CLIENT_DISCONNECTED_ERR: - responseWith(cause.Error(), w, http.StatusBadRequest) + respondWith(cause.Error(), w, http.StatusBadRequest) case READ_RESP_BODY_ERR: - responseWith(cause.Error(), w, http.StatusInternalServerError) + respondWith(cause.Error(), w, http.StatusInternalServerError) case MAX_REQ_BODY_SIZE_ERR: - responseWith(cause.Error(), w, http.StatusRequestEntityTooLarge) + respondWith(cause.Error(), w, http.StatusRequestEntityTooLarge) case FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR, FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR: - responseWith(cause.Error(), w, http.StatusServiceUnavailable) + respondWith(cause.Error(), w, http.StatusServiceUnavailable) } } @@ -119,3 +122,35 @@ func serializeRequest(ctx context.Context, r *http.Request, cancel context.Cance // Send serialized request through channel serializedRequestChannel <- requestBuff.Bytes() } + +// Create HTTP response sent from mmar server to the end-user client +func createSerializedServerResp(status string, statusCode int, body string) bytes.Buffer { + resp := http.Response{ + Status: status, + StatusCode: statusCode, + Body: 
io.NopCloser(bytes.NewBufferString(body)), + } + + // Writing response to buffer to tunnel it back + var responseBuff bytes.Buffer + resp.Write(&responseBuff) + + return responseBuff +} + +// Generate a random ID from ID_CHARSET of length ID_LENGTH +func GenerateRandomID() string { + var randSeed *mathRand.Rand = mathRand.New(mathRand.NewSource(time.Now().UnixNano())) + b := make([]byte, constants.ID_LENGTH) + for i := range b { + b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))] + } + return string(b) +} + +// Generate a random 32-bit unsigned integer +func GenerateRandomUint32() uint32 { + var randomUint32 uint32 + binary.Read(cryptoRand.Reader, binary.BigEndian, &randomUint32) + return randomUint32 +} diff --git a/internal/utils/main.go b/internal/utils/main.go index a35d885..b9aa575 100644 --- a/internal/utils/main.go +++ b/internal/utils/main.go @@ -112,5 +112,14 @@ func NetworkError(err error) bool { return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, net.ErrClosed) || - errors.Is(err, syscall.ECONNRESET) + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, os.ErrDeadlineExceeded) +} + +func EnvVarOrDefault(envVar string, defaultVal string) string { + envValue, ok := os.LookupEnv(envVar) + if !ok { + return defaultVal + } + return envValue } diff --git a/simulations/devserver/main.go b/simulations/devserver/main.go index d7080ca..973d42a 100644 --- a/simulations/devserver/main.go +++ b/simulations/devserver/main.go @@ -25,11 +25,19 @@ type DevServer struct { *httptest.Server } -func NewDevServer() *DevServer { +func NewDevServer(proto string, addr string) *DevServer { mux := setupMux() + var httpServer *httptest.Server + switch proto { + case "https": + httpServer = httptest.NewTLSServer(mux) + case "http": + httpServer = httptest.NewServer(mux) + } + return &DevServer{ - httptest.NewServer(mux), + httpServer, } } @@ -170,19 +178,27 @@ func handleRedirect(w http.ResponseWriter, r *http.Request) { // 
Request handler that returns an invalid HTTP response func handleBadResp(w http.ResponseWriter, r *http.Request) { - // Return a response with Content-Length headers that do not match the actual data - respBody, err := json.Marshal(map[string]interface{}{ - "data": "some data", - }) + // Get the underlying connection object + // Assert that w supports Hijacking + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Hijacking not supported", http.StatusInternalServerError) + return + } + // Hijack the connection + conn, buf, err := hijacker.Hijack() if err != nil { - log.Fatalf("Failed to marshal response for GET: %v", err) + http.Error(w, "Hijacking failed", http.StatusInternalServerError) + return } + defer conn.Close() - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Length", "123") // Content length much larger than actual content - w.WriteHeader(http.StatusOK) - w.Write(respBody) + // Send back an invalid HTTP response + buf.WriteString("some random string\r\n" + + "\r\n" + + "that is not a valid http resp") + buf.Flush() } // Request handler that takes a long time before returning response diff --git a/simulations/simulation_test.go b/simulations/simulation_test.go index 1fc2173..174efec 100644 --- a/simulations/simulation_test.go +++ b/simulations/simulation_test.go @@ -43,7 +43,15 @@ func StartMmarServer(ctx context.Context) { } } -func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort string) { +func StartMmarClient( + ctx context.Context, + urlCh chan string, + localDevServerPort string, + localDevServerHost string, + localDevServerProto string, + customDns string, + customCert string, +) { cmd := exec.CommandContext( ctx, "./mmar", @@ -54,6 +62,26 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort localDevServerPort, ) + if localDevServerHost != "" { + cmd.Args = append(cmd.Args, "--local-host", localDevServerHost) + } + + if localDevServerProto != "" { + 
cmd.Args = append(cmd.Args, "--local-proto", localDevServerProto) + } + + if customDns != "" { + cmd.Args = append(cmd.Args, "--custom-dns", customDns) + } + + if customCert != "" { + cmd.Args = append(cmd.Args, "--custom-cert", customCert) + } + + cmd.Args = append(cmd.Args, "") + + cmd.Stdout = os.Stdout + // Pipe Stderr To capture logs for extracting the tunnel url pipe, _ := cmd.StderrPipe() @@ -77,7 +105,6 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort tunnelUrl := extractTunnelURL(line) if tunnelUrl != "" { urlCh <- tunnelUrl - break } line, readErr = stdoutReader.ReadString('\n') } @@ -91,9 +118,9 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort } } -func StartLocalDevServer() *devserver.DevServer { - ds := devserver.NewDevServer() - log.Printf("Started local dev server on: http://localhost:%v", ds.Port()) +func StartLocalDevServer(proto string, addr string) *devserver.DevServer { + ds := devserver.NewDevServer(proto, addr) + log.Printf("Started local dev server on: %v://%v:%v", proto, addr, ds.Port()) return ds } @@ -115,12 +142,13 @@ func verifyGetRequestSuccess(t *testing.T, client *http.Client, tunnelUrl string resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-get-request-success"}, } @@ -159,12 +187,13 @@ func verifyGetRequestFail(t *testing.T, client *http.Client, tunnelUrl string, w resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang 
client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-get-request-fail"}, } @@ -210,12 +239,13 @@ func verifyPostRequestSuccess(t *testing.T, client *http.Client, tunnelUrl strin resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-post-request-success"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -265,12 +295,13 @@ func verifyPostRequestFail(t *testing.T, client *http.Client, tunnelUrl string, resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-post-request-fail"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -319,12 +350,13 @@ func verifyRedirectsHandled(t *testing.T, client *http.Client, tunnelUrl string, resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-redirect-request"}, "Referer": {tunnelUrl + "/redirect"}, // Include referer header since it redirects } @@ -378,6 +410,7 @@ func verifyInvalidMethodRequestHandled(t *testing.T, client *http.Client, tunnel 
expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-invalid-method-request"}, } @@ -450,7 +483,6 @@ func verifyInvalidHttpVersionRequestHandled(t *testing.T, tunnelUrl string, wg * if respErr != nil { t.Errorf("%v: Failed to get response %v", "verifyInvalidHttpVersionRequestHandled", respErr) } - if resp.StatusCode != http.StatusBadRequest { t.Errorf( "%v: resp.StatusCode = %v; want %v", @@ -546,7 +578,6 @@ func verifyContentLengthWithNoBodyRequestHandled(t *testing.T, tunnelUrl string, if respErr != nil { t.Errorf("%v: Failed to get response %v", "verifyContentLengthWithNoBodyRequestHandled", respErr) } - expectedBody := constants.READ_BODY_CHUNK_TIMEOUT_ERR_TEXT expectedResp := expectedResponse{ @@ -580,12 +611,13 @@ func verifyRequestWithLargeBody(t *testing.T, client *http.Client, tunnelUrl str resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedReqHeaders := map[string][]string{ "User-Agent": {"Go-http-client/1.1"}, // Default header in golang client "Accept-Encoding": {"gzip"}, // Default header in golang client + "Connection": {"close"}, "Simulation-Test": {"verify-large-post-request-success"}, "Content-Length": {strconv.Itoa(len(serializedReqBody))}, } @@ -634,7 +666,11 @@ func verifyRequestWithVeryLargeBody(t *testing.T, client *http.Client, tunnelUrl resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + // Check if connection was closed in the middle of writing, that's also valid behavior + if !strings.Contains(respErr.Error(), "write: connection reset by peer") { + t.Errorf("Failed to get response: %v", respErr) + } + return } expectedBody := constants.MAX_REQ_BODY_SIZE_ERR_TEXT @@ -661,7 +697,7 @@ func 
verifyDevServerReturningInvalidRespHandled(t *testing.T, client *http.Clien resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.READ_RESP_BODY_ERR_TEXT @@ -688,7 +724,7 @@ func verifyDevServerLongRunningReqHandledGradefully(t *testing.T, client *http.C resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT @@ -715,7 +751,7 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu resp, respErr := client.Do(req) if respErr != nil { - log.Printf("Failed to get response: %v", respErr) + t.Errorf("Failed to get response: %v", respErr) } expectedBody := constants.LOCALHOST_NOT_RUNNING_ERR_TEXT @@ -735,19 +771,45 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu func TestSimulation(t *testing.T) { simulationCtx, simulationCancel := context.WithCancel(context.Background()) - localDevServer := StartLocalDevServer() + // Start a local dev server with http + localDevServer := StartLocalDevServer("http", "localhost") defer localDevServer.Close() + // Start a local dev server with https + localDevTLSServer := StartLocalDevServer("https", "example.com") + defer localDevTLSServer.Close() + + // Write cert to file so we are able to pass it into mmar client + certErr := os.WriteFile("./temp-cert", localDevTLSServer.Certificate().Raw, 0644) // 0644 is file permissions + if certErr != nil { + log.Fatal(certErr) + } + go dnsserver.StartDnsServer() go StartMmarServer(simulationCtx) wait := time.NewTimer(2 * time.Second) <-wait.C - clientUrlCh := make(chan string) - go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port()) - // Wait for tunnel url - tunnelUrl := <-clientUrlCh + // Start a basic mmar client + basicClientUrlCh := 
make(chan string) + go StartMmarClient(simulationCtx, basicClientUrlCh, localDevServer.Port(), "", "", "", "") + + // Start another basic mmar client + basicClientUrlCh2 := make(chan string) + go StartMmarClient(simulationCtx, basicClientUrlCh2, localDevServer.Port(), "", "", "", "") + + // Wait for all tunnel urls + mmarClientsCount := 2 + tunnelUrls := []string{} + for range mmarClientsCount { + select { + case tunnelUrl := <-basicClientUrlCh: + tunnelUrls = append(tunnelUrls, tunnelUrl) + case tunnelUrl := <-basicClientUrlCh2: + tunnelUrls = append(tunnelUrls, tunnelUrl) + } + } // Initialize http client client := httpClient() @@ -783,18 +845,27 @@ func TestSimulation(t *testing.T) { verifyContentLengthWithNoBodyRequestHandled, } - for _, simTest := range simulationTests { - wg.Add(1) - go simTest(t, client, tunnelUrl, &wg) - } + // Loop through all tunnel urls and run simulation tests + for _, tunnelUrl := range tunnelUrls { + + for _, simTest := range simulationTests { + wg.Add(1) + go simTest(t, client, tunnelUrl, &wg) + } - for _, manualClientSimTest := range manualClientSimulationTests { - wg.Add(1) - go manualClientSimTest(t, tunnelUrl, &wg) + for _, manualClientSimTest := range manualClientSimulationTests { + wg.Add(1) + go manualClientSimTest(t, tunnelUrl, &wg) + } } wg.Wait() + // Delete cert file + if rmErr := os.Remove("./temp-cert"); rmErr != nil { + log.Fatal(rmErr) + } + // Stop simulation tests simulationCancel() diff --git a/simulations/simulation_utils.go b/simulations/simulation_utils.go index d6ae1bf..0a13fab 100644 --- a/simulations/simulation_utils.go +++ b/simulations/simulation_utils.go @@ -95,7 +95,8 @@ func httpClient() *http.Client { dialer := initCustomDialer() tp := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: dialer.DialContext, + DisableKeepAlives: true, } client := &http.Client{Transport: tp} return client