diff --git a/.github/workflows/commit_lint.yml b/.github/workflows/commit_lint.yml
new file mode 100644
index 0000000..2df66ba
--- /dev/null
+++ b/.github/workflows/commit_lint.yml
@@ -0,0 +1,49 @@
+name: Lint Commit Messages
+
+on:
+  workflow_dispatch:
+  pull_request:
+
+permissions:
+  contents: read
+
+jobs:
+  commitlint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup node
+        uses: actions/setup-node@v4
+        with:
+          node-version: lts/*
+
+      - name: Install commitlint
+        run: npm install -D @commitlint/cli @commitlint/config-conventional
+
+      - name: Print versions
+        run: |
+          git --version
+          node --version
+          npm --version
+          npx commitlint --version
+
+      - name: Create default conventional commitlint config
+        run: |
+          echo "module.exports = {extends: ['@commitlint/config-conventional']};" > commitlint.config.js
+
+      # Only pull_request events carry base/head SHAs; they are passed via env
+      # rather than interpolated directly into the shell command.
+      - name: Validate PR commits with commitlint
+        if: github.event_name == 'pull_request'
+        env:
+          BASE_SHA: ${{ github.event.pull_request.base.sha }}
+          HEAD_SHA: ${{ github.event.pull_request.head.sha }}
+        run: npx commitlint --from "$BASE_SHA" --to "$HEAD_SHA" --verbose
+
+      # Manual runs have no pull request, so lint the most recent commit instead.
+      - name: Validate last commit with commitlint
+        if: github.event_name == 'workflow_dispatch'
+        run: npx commitlint --last --verbose
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 11f6005..aec6f5c 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -14,6 +14,7 @@ builds:
goarch:
- amd64
- arm64
+ - 386 # Add support for Windows 32-bit (x86)
archives:
- formats: [ 'tar.gz' ]
@@ -36,6 +37,7 @@ changelog:
exclude:
- "^docs:"
- "^test:"
+ - "^chore:"
release:
github:
diff --git a/README.md b/README.md
index eb54fc1..af21345 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,3 @@
-> WARNING: This is still in Alpha stage, not ready for production use yet.
-
@@ -8,21 +6,33 @@
# mmar
-mmar (pronounced "ma-mar") is a zero-dependancy, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL.
+mmar (pronounced "ma-mar") is a zero-dependency, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL.
It allows you to quickly share what you are working on locally with others without the hassle of a full deployment, especially if it is not ready to be shipped.
-
+### Demo
+
+![mmar demo](docs/assets/img/mmar-demo.gif)
+
+
+
+
### Key Features
- Super simple to use
-- Utilize "mmar.dev" to tunnel for free on a generated subdomain
+- Provides "mmar.dev" to tunnel for free on a generated subdomain
- Expose multiple ports on different subdomains
- Live logs of requests coming into your localhost server
-- Zero dependancies
+- Zero dependencies
- Self-host your own mmar server to have full control
+### Limitations
+
+- Currently only supports the HTTP protocol; other protocols, such as WebSockets, have not been tested and will likely not work
+- Requests through mmar are limited to 10 MB in size; this could be made configurable in the future
+- There is a limit of 5 mmar tunnels per IP to avoid abuse; this could also be made configurable in the future
+
### Learn More
The development, implementation and technical details of mmar has all been documented in a [devlog series](https://ymusleh.com/tags/mmar.html). You can read more about it there.
@@ -31,7 +41,15 @@ _p.s. mmar means “corridor” or “pass-through” in Arabic._
## Installation
-### MacOS
+### Linux/MacOS
+
+Install mmar
+
+```sh
+curl -sSL https://raw.githubusercontent.com/yusuf-musleh/mmar/refs/heads/master/install.sh | sh
+```
+
+### MacOS (Homebrew)
Use [Homebrew](https://brew.sh/) to install `mmar` on MacOS:
@@ -50,16 +68,12 @@ brew upgrade yusuf-musleh/mmar-tap/mmar
The fastest way to create a tunnel what is running on your `localhost:8080` using [Docker](https://www.docker.com/) is by running this command:
```
-docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.1.6 client --local-port 8080
+docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.2.6 client --local-port 8080
```
-### Linux
-
-TBD -- see Docker or Manual installation instructions for now
-
### Windows
-TBD -- see Docker or Manual installation instructions for now
+See Docker or Manual installation instructions
### Manually
@@ -71,7 +85,7 @@ Download a [Release](https://github.com/yusuf-musleh/mmar/releases/) from Github
```
$ mmar version
-mmar version 0.1.6
+mmar version 0.2.6
```
1. Make sure you have your localhost server running on some port (eg: 8080)
1. Run the `mmar` client, pointing it to your localhost port
@@ -111,9 +125,163 @@ Commands:
Run `mmar -h` to get help for a specific command
```
+### Configuring through Environment Variables
+
+You can define the various mmar command flags in environment variables rather than passing them in with the command. Here are the available environment variables along with the corresponding flags:
+
+```
+MMAR__SERVER_HTTP_PORT -> mmar server --http-port
+MMAR__SERVER_TCP_PORT -> mmar server --tcp-port
+MMAR__LOCAL_PORT -> mmar client --local-port
+MMAR__TUNNEL_HTTP_PORT -> mmar client --tunnel-http-port
+MMAR__TUNNEL_TCP_PORT -> mmar client --tunnel-tcp-port
+MMAR__TUNNEL_HOST -> mmar client --tunnel-host
+MMAR__CUSTOM_DNS -> mmar client --custom-dns
+MMAR__CUSTOM_CERT -> mmar client --custom-cert
+```
+
## Self-Host
-TBD
+Since everything is open-source, you can easily self-host mmar on your own infrastructure under your own domain.
+
+To deploy mmar on your own VPS using docker, you can do the following:
+
+1. Make sure you have your VPS already provisioned and have ssh access to it.
+1. Make sure you already own a domain and have the apex domain as well as wildcard subdomains pointing towards your VPS's public IP. It should look something like this:
+
+
+ | Type | Host | Value | TTL |
+ | -------- | ------- | --------------- | ------ |
+ | A Record | * | 123.123.123.123 | Auto |
+ | A Record | @ | 123.123.123.123 | Auto |
+
+ This would direct all your tunnel subdomains to your VPS for mmar to handle.
+
+1. Next, make sure you have docker installed on your VPS, and create a `compose.yaml` file and add the mmar server as a service:
+
+ ```yaml
+ services:
+ mmar-server:
+ image: "ghcr.io/yusuf-musleh/mmar:v0.2.6" # <----- make sure to use the latest mmar version
+ restart: unless-stopped
+ command: server
+ environment:
+ - USERNAME_HASH=[YOUR_SHA256_USERNAME_HASH]
+ - PASSWORD_HASH=[YOUR_SHA256_PASSWORD_HASH]
+ ports:
+ - "3376:3376"
+ - "6673:6673"
+ ```
+
+ The `USERNAME_HASH` and `PASSWORD_HASH` env variables are the SHA-256 hashes of the credentials needed to access the stats page, which can be viewed at `stats.yourdomain.com`. The stats page returns JSON with very basic information about the number of clients connected (i.e. tunnels open) along with a list of the subdomains and when they were created:
+
+ ```json
+ {
+ "connectedClients": [
+ {
+ "createdOn": "2025-03-01T08:01:46Z",
+ "id": "owrwf0"
+ }
+ ],
+ "connectedClientsCount": 1
+ }
+ ```
+
+1. Next, we need to also add a reverse proxy, such as [Nginx](https://nginx.org/) or [Caddy](https://caddyserver.com/), so that requests and TCP connections to your domain are routed accordingly. Since the mmar client communicates with the server using TCP, you need to make sure that the reverse proxy supports routing on TCP, and not just HTTP.
+
+ I highly recommend [Caddy](https://caddyserver.com/) as it handles obtaining SSL certificates for your wildcard subdomains automatically, in addition to providing a Layer4 reverse proxy to route TCP connections. To get this functionality we need to include a few additional Caddy modules: the [layer4 module](https://github.com/mholt/caddy-l4) and the [caddy-dns](https://github.com/caddy-dns) module that matches your domain registrar. In my case, I am using the [namecheap module](https://github.com/caddy-dns/namecheap) to automatically issue SSL certificates for wildcard subdomains.
+
+ To add those modules you can build a new docker image for Caddy including these modules:
+
+ ```
+ FROM caddy:2.9.1-builder AS builder
+
+ RUN xcaddy build \
+ --with github.com/mholt/caddy-l4 \
+ --with github.com/caddy-dns/namecheap
+
+ FROM caddy:2.9.1
+
+ COPY --from=builder /usr/bin/caddy /usr/bin/caddy
+ ```
+
+ Next we need to configure Caddy for our setup, let's define a `Caddyfile`:
+
+ ```
+ # Layer4 reverse proxy TCP connection to TCP port
+ {
+ layer4 {
+ example.com:6673 {
+ route {
+ tls
+ proxy {
+ upstream mmar-server:6673
+ }
+ }
+ }
+ }
+ }
+
+ # Redirect to repo page on Github
+ example.com {
+ redir https://github.com/yusuf-musleh/mmar
+ }
+
+ # Reverse proxy HTTP requests to HTTP port
+ *.example.com {
+ reverse_proxy mmar-server:3376
+ tls {
+ resolvers 1.1.1.1
+ dns namecheap {
+ api_key API_KEY_HERE
+ user USERNAME_HERE
+ api_endpoint https://api.namecheap.com/xml.response
+ client_ip IP_HERE
+ }
+ }
+ }
+ ```
+
+ Now that we have built the new Caddy image (e.g. `docker build -t custom-caddy:2.9.1 .`) and defined our Caddyfile, we just need to update our `compose.yaml` file to start Caddy:
+
+ ```yaml
+ services:
+ caddy:
+ image: custom-caddy:2.9.1
+ restart: unless-stopped
+ ports:
+ - "80:80"
+ - "443:443"
+ - "443:443/udp"
+ volumes:
+ - ./Caddyfile:/etc/caddy/Caddyfile
+ - caddy_data:/data
+ - caddy_config:/config
+ mmar-server:
+ image: "ghcr.io/yusuf-musleh/mmar:v0.2.6" # <----- make sure to use the latest mmar version
+ restart: unless-stopped
+ command: server
+ environment:
+ - USERNAME_HASH=[YOUR_SHA256_USERNAME_HASH]
+ - PASSWORD_HASH=[YOUR_SHA256_PASSWORD_HASH]
+ ports:
+ - "3376:3376"
+ - "6673:6673"
+
+ volumes:
+ caddy_data:
+ caddy_config:
+
+ ```
+
+ That's it! All you need to do is run `docker compose up -d` and then check the logs to make sure everything is running as expected, `docker compose logs --follow`.
+
+1. To create a tunnel using your self-hosted mmar server, run the following command on your local machine:
+
+ ```
+ $ mmar client --tunnel-host example.com --local-port 8080
+ ```
+
+ That should open a mmar tunnel through your self-hosted mmar server pointing towards your `localhost:8080`.
+
## License
diff --git a/cmd/mmar/main.go b/cmd/mmar/main.go
index 87b9778..edac2bb 100644
--- a/cmd/mmar/main.go
+++ b/cmd/mmar/main.go
@@ -14,24 +14,46 @@ import (
func main() {
serverCmd := flag.NewFlagSet(constants.SERVER_CMD, flag.ExitOnError)
serverHttpPort := serverCmd.String(
- "http-port", constants.SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT_HELP,
+ "http-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT),
+ constants.SERVER_HTTP_PORT_HELP,
)
serverTcpPort := serverCmd.String(
- "tcp-port", constants.SERVER_TCP_PORT, constants.SERVER_TCP_PORT_HELP,
+ "tcp-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_TCP_PORT, constants.SERVER_TCP_PORT),
+ constants.SERVER_TCP_PORT_HELP,
)
clientCmd := flag.NewFlagSet(constants.CLIENT_CMD, flag.ExitOnError)
clientLocalPort := clientCmd.String(
- "local-port", constants.CLIENT_LOCAL_PORT, constants.CLIENT_LOCAL_PORT_HELP,
+ "local-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_LOCAL_PORT, constants.CLIENT_LOCAL_PORT),
+ constants.CLIENT_LOCAL_PORT_HELP,
)
clientTunnelHttpPort := clientCmd.String(
- "tunnel-http-port", constants.TUNNEL_HTTP_PORT, constants.CLIENT_HTTP_PORT_HELP,
+ "tunnel-http-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HTTP_PORT, constants.TUNNEL_HTTP_PORT),
+ constants.CLIENT_HTTP_PORT_HELP,
)
clientTunnelTcpPort := clientCmd.String(
- "tunnel-tcp-port", constants.SERVER_TCP_PORT, constants.CLIENT_TCP_PORT_HELP,
+ "tunnel-tcp-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_TCP_PORT, constants.SERVER_TCP_PORT),
+ constants.CLIENT_TCP_PORT_HELP,
)
clientTunnelHost := clientCmd.String(
- "tunnel-host", constants.TUNNEL_HOST, constants.TUNNEL_HOST_HELP,
+ "tunnel-host",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HOST, constants.TUNNEL_HOST),
+ constants.TUNNEL_HOST_HELP,
+ )
+ clientCustomDns := clientCmd.String(
+ "custom-dns",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_DNS, ""),
+ constants.CLIENT_CUSTOM_DNS_HELP,
+ )
+ clientCustomCert := clientCmd.String(
+ "custom-cert",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_CERT, ""),
+ constants.CLIENT_CUSTOM_CERT_HELP,
)
versionCmd := flag.NewFlagSet(constants.VERSION_CMD, flag.ExitOnError)
@@ -59,6 +81,8 @@ func main() {
TunnelHttpPort: *clientTunnelHttpPort,
TunnelTcpPort: *clientTunnelTcpPort,
TunnelHost: *clientTunnelHost,
+ CustomDns: *clientCustomDns,
+ CustomCert: *clientCustomCert,
}
client.Run(mmarClientConfig)
case constants.VERSION_CMD:
diff --git a/constants/main.go b/constants/main.go
index 19a2bd6..51ea7c5 100644
--- a/constants/main.go
+++ b/constants/main.go
@@ -1,7 +1,7 @@
package constants
const (
- MMAR_VERSION = "0.2.1"
+ MMAR_VERSION = "0.2.6"
VERSION_CMD = "version"
SERVER_CMD = "server"
@@ -12,29 +12,44 @@ const (
TUNNEL_HOST = "mmar.dev"
TUNNEL_HTTP_PORT = "443"
+ MMAR_ENV_VAR_SERVER_HTTP_PORT = "MMAR__SERVER_HTTP_PORT"
+ MMAR_ENV_VAR_SERVER_TCP_PORT = "MMAR__SERVER_TCP_PORT"
+ MMAR_ENV_VAR_LOCAL_PORT = "MMAR__LOCAL_PORT"
+ MMAR_ENV_VAR_TUNNEL_HTTP_PORT = "MMAR__TUNNEL_HTTP_PORT"
+ MMAR_ENV_VAR_TUNNEL_TCP_PORT = "MMAR__TUNNEL_TCP_PORT"
+ MMAR_ENV_VAR_TUNNEL_HOST = "MMAR__TUNNEL_HOST"
+ MMAR_ENV_VAR_CUSTOM_DNS = "MMAR__CUSTOM_DNS"
+ MMAR_ENV_VAR_CUSTOM_CERT = "MMAR__CUSTOM_CERT"
+
SERVER_STATS_DEFAULT_USERNAME = "admin"
SERVER_STATS_DEFAULT_PASSWORD = "admin"
SERVER_HTTP_PORT_HELP = "Define port where mmar will bind to and run on server for HTTP requests."
SERVER_TCP_PORT_HELP = "Define port where mmar will bind to and run on server for TCP connections."
- CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar."
- CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel."
- CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel."
- TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to."
+ CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar."
+ CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel."
+ CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel."
+ TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to."
+ CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)"
+ CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. (eg: /path/to/cert)"
- TUNNEL_MESSAGE_PROTOCOL_VERSION = 1
+ TUNNEL_MESSAGE_PROTOCOL_VERSION = 4
TUNNEL_MESSAGE_DATA_DELIMITER = '\n'
ID_CHARSET = "abcdefghijklmnopqrstuvwxyz0123456789"
ID_LENGTH = 6
- MAX_TUNNELS_PER_IP = 5
- TUNNEL_RECONNECT_TIMEOUT = 3
- GRACEFUL_SHUTDOWN_TIMEOUT = 3
- TUNNEL_CREATE_TIMEOUT = 3
- REQ_BODY_READ_CHUNK_TIMEOUT = 3
- DEST_REQUEST_TIMEOUT = 30
- MAX_REQ_BODY_SIZE = 10000000 // 10mb
+ MAX_TUNNELS_PER_IP = 5
+ TUNNEL_RECONNECT_TIMEOUT = 3
+ GRACEFUL_SHUTDOWN_TIMEOUT = 3
+ TUNNEL_CREATE_TIMEOUT = 3
+ REQ_BODY_READ_CHUNK_TIMEOUT = 3
+ DEST_REQUEST_TIMEOUT = 30
+ HEARTBEAT_FROM_SERVER_TIMEOUT = 5
+ HEARTBEAT_FROM_CLIENT_TIMEOUT = 2
+ READ_DEADLINE = 3
+ MAX_REQ_BODY_SIZE = 10000000 // 10mb
+ REQUEST_ID_BUFF_SIZE = 4
CLIENT_DISCONNECT_ERR_TEXT = "Tunnel is closed, cannot connect to mmar client."
LOCALHOST_NOT_RUNNING_ERR_TEXT = "Tunneled successfully, but nothing is running on localhost."
diff --git a/docs/assets/img/mmar-demo.gif b/docs/assets/img/mmar-demo.gif
new file mode 100644
index 0000000..636a18b
Binary files /dev/null and b/docs/assets/img/mmar-demo.gif differ
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000..20efdd8
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,52 @@
+#!/bin/sh
+
+set -e
+
+REPO="yusuf-musleh/mmar"
+BINARY="mmar"
+
+echo "Installing $BINARY..."
+
+# Detect OS
+OS="$(uname -s)"
+case "$OS" in
+ Linux) OS_TITLE="Linux";;
+ Darwin) OS_TITLE="Darwin";;
+ *) echo "Unsupported OS: $OS"; exit 1;;
+esac
+
+# Detect ARCH
+ARCH="$(uname -m)"
+case "$ARCH" in
+ x86_64) ARCH_ID="x86_64";;
+ i386) ARCH_ID="i386";;
+ aarch64|arm64) ARCH_ID="arm64";;
+ *) echo "Unsupported architecture: $ARCH"; exit 1;;
+esac
+
+ASSET="${BINARY}_${OS_TITLE}_${ARCH_ID}.tar.gz"
+URL="https://github.com/$REPO/releases/latest/download/$ASSET"
+
+# Temp dir
+TMP_DIR=$(mktemp -d)
+cd "$TMP_DIR"
+
+echo "Downloading $ASSET..."
+curl -sSL "$URL" -o "$ASSET"
+
+echo "Extracting..."
+tar -xzf "$ASSET"
+
+# Install location
+INSTALL_DIR="/usr/local/bin"
+
+# Ensure /usr/local/lib exists
+if [ ! -d "$INSTALL_DIR" ]; then
+ sudo mkdir -p "$INSTALL_DIR"
+fi
+
+echo "Installing to $INSTALL_DIR"
+sudo install -m 755 "$BINARY" "$INSTALL_DIR/$BINARY"
+
+echo "$BINARY installed successfully to $INSTALL_DIR/$BINARY"
+"$BINARY" version || true
diff --git a/internal/client/main.go b/internal/client/main.go
index f240ef1..2eb8a92 100644
--- a/internal/client/main.go
+++ b/internal/client/main.go
@@ -4,6 +4,8 @@ import (
"bufio"
"bytes"
"context"
+ "crypto/tls"
+ "crypto/x509"
"errors"
"fmt"
"io"
@@ -26,6 +28,8 @@ type ConfigOptions struct {
TunnelHttpPort string
TunnelTcpPort string
TunnelHost string
+ CustomDns string
+ CustomCert string
}
type MmarClient struct {
@@ -58,9 +62,64 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
},
}
+ // Use custom DNS if set
+ if mc.CustomDns != "" {
+ r := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ return net.Dial("udp", mc.CustomDns)
+ },
+ }
+ dialer := &net.Dialer{
+ Resolver: r,
+ }
+
+ tp := &http.Transport{
+ DialContext: dialer.DialContext,
+ }
+
+ fwdClient.Transport = tp
+ }
+
+ // Use custom TLS certificate if setup
+ if mc.CustomCert != "" {
+ certData, certFileErr := os.ReadFile(mc.CustomCert)
+ if certFileErr != nil {
+ logger.Log(
+ constants.RED,
+ fmt.Sprintf(
+ "Could not read certificate from file: %v",
+ certFileErr,
+ ))
+ os.Exit(1)
+ }
+
+ cert, certErr := x509.ParseCertificate(certData)
+ if certErr != nil {
+ logger.Log(constants.YELLOW, "Warning: Could not load custom certificate")
+ } else {
+ fwdClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{
+ RootCAs: x509.NewCertPool(),
+ }
+ fwdClient.Transport.(*http.Transport).TLSClientConfig.RootCAs.AddCert(cert)
+ }
+ }
+
reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData))
- req, reqErr := http.ReadRequest(reqReader)
+ // Extract RequestId
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ _, err := io.ReadFull(reqReader, reqIdBuff)
+ if err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse RequestId for request: %v\n", err))
+ return
+ }
+
+ // Include RequestId in tunnel back message
+ msgData := []byte{}
+ msgData = append(msgData, reqIdBuff...)
+
+ req, reqErr := http.ReadRequest(reqReader)
if reqErr != nil {
if errors.Is(reqErr, io.EOF) {
logger.Log(constants.DEFAULT_COLOR, "Connection to mmar server closed or disconnected. Exiting...")
@@ -80,32 +139,36 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
resp, fwdErr := fwdClient.Do(req)
if fwdErr != nil {
if errors.Is(fwdErr, syscall.ECONNREFUSED) || errors.Is(fwdErr, io.ErrUnexpectedEOF) || errors.Is(fwdErr, io.EOF) {
- localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING}
+ localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING, MsgData: msgData}
if err := mc.SendMessage(localhostNotRunningMsg); err != nil {
log.Fatal(err)
}
return
} else if errors.Is(fwdErr, context.DeadlineExceeded) {
- destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT}
+ destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT, MsgData: msgData}
if err := mc.SendMessage(destServerTimedoutMsg); err != nil {
log.Fatal(err)
}
return
}
- log.Fatalf("Failed to forward: %v", fwdErr)
+ invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST, MsgData: msgData}
+ if err := mc.SendMessage(invalidRespFromDestMsg); err != nil {
+ log.Fatal(err)
+ }
+ return
}
- logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true)
-
// Writing response to buffer to tunnel it back
var responseBuff bytes.Buffer
resp.Write(&responseBuff)
-
- respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
+ msgData = append(msgData, responseBuff.Bytes()...)
+ respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: msgData}
if err := mc.SendMessage(respMessage); err != nil {
log.Fatal(err)
}
+
+ logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true)
}
// Keep attempting to reconnect the existing tunnel until successful
@@ -118,7 +181,7 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) {
logger.Log(constants.DEFAULT_COLOR, "Attempting to reconnect...")
conn, err := net.DialTimeout(
"tcp",
- fmt.Sprintf("%s:%s", mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort),
+ net.JoinHostPort(mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort),
constants.TUNNEL_CREATE_TIMEOUT*time.Second,
)
if err != nil {
@@ -126,6 +189,15 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) {
continue
}
mc.Tunnel.Conn = conn
+ mc.Tunnel.Reader = bufio.NewReader(conn)
+
+ // Try to reclaim the same subdomain
+ reclaimTunnelMsg := protocol.TunnelMessage{MsgType: protocol.RECLAIM_TUNNEL, MsgData: []byte(mc.subdomain)}
+ if err := mc.SendMessage(reclaimTunnelMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
+ os.Exit(0)
+ }
+
break
}
}
@@ -136,13 +208,35 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
case <-ctx.Done(): // Client gracefully shutdown
return
default:
+ // Send heartbeat if nothing has been read for a while
+ receiveMessageTimeout := time.AfterFunc(
+ constants.HEARTBEAT_FROM_CLIENT_TIMEOUT*time.Second,
+ func() {
+ heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_CLIENT}
+ if err := mc.SendMessage(heartbeatMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to send heartbeat. Exiting...")
+ os.Exit(0)
+ }
+ // Set a read timeout, if no response to heartbeat is recieved within that period,
+ // attempt to reconnect to the server
+ readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second))
+ mc.Tunnel.Conn.SetReadDeadline(readDeadline)
+ },
+ )
+
tunnelMsg, err := mc.ReceiveMessage()
+ // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout
+ // as we do not need to send heartbeat or check connection health in this iteration
+ receiveMessageTimeout.Stop()
+ mc.Tunnel.Conn.SetReadDeadline(time.Time{})
+
if err != nil {
// If the context was cancelled just return
if errors.Is(ctx.Err(), context.Canceled) {
return
- } else if errors.Is(err, os.ErrDeadlineExceeded) {
- continue
+ } else if errors.Is(err, protocol.INVALID_MESSAGE_PROTOCOL_VERSION) {
+ logger.Log(constants.YELLOW, "The mmar message protocol has been updated, please update mmar.")
+ os.Exit(0)
}
logger.Log(constants.DEFAULT_COLOR, "Tunnel connection disconnected.")
@@ -154,21 +248,9 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
}
switch tunnelMsg.MsgType {
- case protocol.CLIENT_CONNECT:
+ case protocol.TUNNEL_CREATED, protocol.TUNNEL_RECLAIMED:
tunnelSubdomain := string(tunnelMsg.MsgData)
- // If there is an existing subdomain, that means we are reconnecting with an
- // existing mmar client, try to reclaim the same subdomain
- if mc.subdomain != "" {
- reconnectMsg := protocol.TunnelMessage{MsgType: protocol.CLIENT_RECLAIM_SUBDOMAIN, MsgData: []byte(tunnelSubdomain + ":" + mc.subdomain)}
- mc.subdomain = ""
- if err := mc.SendMessage(reconnectMsg); err != nil {
- logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
- os.Exit(0)
- }
- continue
- } else {
- mc.subdomain = tunnelSubdomain
- }
+ mc.subdomain = tunnelSubdomain
logger.LogTunnelCreated(tunnelSubdomain, mc.TunnelHost, mc.TunnelHttpPort, mc.LocalPort)
case protocol.CLIENT_TUNNEL_LIMIT:
limit := logger.ColorLogStr(
@@ -184,6 +266,15 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
os.Exit(0)
case protocol.REQUEST:
go mc.handleRequestMessage(tunnelMsg)
+ case protocol.HEARTBEAT_ACK:
+ // Got a heartbeat ack, that means the connection is healthy,
+ // we do not need to perform any action
+ case protocol.HEARTBEAT_FROM_SERVER:
+ heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK}
+ if err := mc.SendMessage(heartbeatAckMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to send Heartbeat Ack. Exiting...")
+ os.Exit(0)
+ }
}
}
}
@@ -198,7 +289,7 @@ func Run(config ConfigOptions) {
conn, err := net.DialTimeout(
"tcp",
- fmt.Sprintf("%s:%s", config.TunnelHost, config.TunnelTcpPort),
+ net.JoinHostPort(config.TunnelHost, config.TunnelTcpPort),
constants.TUNNEL_CREATE_TIMEOUT*time.Second,
)
if err != nil {
@@ -215,7 +306,7 @@ func Run(config ConfigOptions) {
}
defer conn.Close()
mmarClient := MmarClient{
- protocol.Tunnel{Conn: conn},
+ protocol.Tunnel{Conn: conn, Reader: bufio.NewReader(conn)},
config,
"",
}
@@ -226,6 +317,12 @@ func Run(config ConfigOptions) {
// Process Tunnel Messages coming from mmar server
go mmarClient.ProcessTunnelMessages(ctx)
+ createTunnelMsg := protocol.TunnelMessage{MsgType: protocol.CREATE_TUNNEL}
+ if err := mmarClient.SendMessage(createTunnelMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to create Tunnel. Exiting...")
+ os.Exit(0)
+ }
+
// Wait for an interrupt signal, if received, terminate gracefully
<-sigInt
diff --git a/internal/protocol/main.go b/internal/protocol/main.go
index c0b3241..1287734 100644
--- a/internal/protocol/main.go
+++ b/internal/protocol/main.go
@@ -19,12 +19,18 @@ import (
const (
REQUEST = uint8(iota + 1)
RESPONSE
- CLIENT_CONNECT
- CLIENT_RECLAIM_SUBDOMAIN
+ CREATE_TUNNEL // client -> server: request a new tunnel (sent by the client right after connecting)
+ RECLAIM_TUNNEL // client -> server: after reconnecting, MsgData carries the previously assigned subdomain
+ TUNNEL_CREATED // server -> client: client reads MsgData as its newly generated subdomain
+ TUNNEL_RECLAIMED // server -> client: client reads MsgData as its reclaimed subdomain
CLIENT_DISCONNECT
CLIENT_TUNNEL_LIMIT
LOCALHOST_NOT_RUNNING
DEST_REQUEST_TIMEDOUT
+ HEARTBEAT_FROM_CLIENT // client -> server: sent when nothing was received for HEARTBEAT_FROM_CLIENT_TIMEOUT seconds
+ HEARTBEAT_FROM_SERVER // server -> client: the client answers with HEARTBEAT_ACK
+ HEARTBEAT_ACK // reply to a heartbeat; the client takes no further action on receiving one
+ INVALID_RESP_FROM_DEST // client -> server: forwarding to the local server failed; keep last, isValidTunnelMessageType iterates up to it
)
var INVALID_MESSAGE_PROTOCOL_VERSION = errors.New("Invalid Message Protocol Version")
@@ -33,7 +39,7 @@ var INVALID_MESSAGE_TYPE = errors.New("Invalid Tunnel Message Type")
func isValidTunnelMessageType(mt uint8) (uint8, error) {
// Iterate through all the message type, from first to last, checking
// if the provided message type matches one of them
- for msgType := REQUEST; msgType <= DEST_REQUEST_TIMEDOUT; msgType++ {
+ for msgType := REQUEST; msgType <= INVALID_RESP_FROM_DEST; msgType++ {
if mt == msgType {
return msgType, nil
}
@@ -45,9 +51,10 @@ func isValidTunnelMessageType(mt uint8) (uint8, error) {
func TunnelErrState(errState uint8) string {
// TODO: Have nicer/more elaborative error messages/pages
errStates := map[uint8]string{
- CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT,
- LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT,
- DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT,
+ CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT,
+ LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT,
+ DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT,
+ INVALID_RESP_FROM_DEST: constants.READ_RESP_BODY_ERR_TEXT,
}
fallbackErr := "An error occured while attempting to tunnel."
@@ -71,6 +78,7 @@ type Tunnel struct {
Id string
Conn net.Conn
CreatedOn time.Time
+ Reader *bufio.Reader
}
type TunnelInterface interface {
@@ -173,6 +181,10 @@ func (tm *TunnelMessage) deserializeMessage(reader *bufio.Reader) error {
return nil
}
+// ReservedSubdomain reports whether this tunnel already has a subdomain
+// assigned to it, i.e. whether its Id is non-empty.
+func (t *Tunnel) ReservedSubdomain() bool {
+	return len(t.Id) > 0
+}
+
func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
// Serialize tunnel message data
serializedMsg, serializeErr := tunnelMsg.serializeMessage()
@@ -184,11 +196,9 @@ func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
}
func (t *Tunnel) ReceiveMessage() (TunnelMessage, error) {
- msgReader := bufio.NewReader(t.Conn)
-
// Read and deserialize tunnel message data
tunnelMessage := TunnelMessage{}
- deserializeErr := tunnelMessage.deserializeMessage(msgReader)
+ deserializeErr := tunnelMessage.deserializeMessage(t.Reader) // Reuse the tunnel's long-lived reader (expected to be set when the Tunnel is built): a fresh bufio.Reader per call would discard bytes it buffered beyond this message.
return tunnelMessage, deserializeErr
}
diff --git a/internal/server/main.go b/internal/server/main.go
index fee090f..594ba88 100644
--- a/internal/server/main.go
+++ b/internal/server/main.go
@@ -4,19 +4,18 @@ import (
"bufio"
"bytes"
"context"
+ "encoding/binary"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"log"
- "math/rand"
"net"
"net/http"
"os"
"os/signal"
"slices"
- "strings"
"sync"
"time"
@@ -44,7 +43,6 @@ type IncomingRequest struct {
responseWriter http.ResponseWriter
request *http.Request
cancel context.CancelCauseFunc
- serializedReq []byte
ctx context.Context
}
@@ -53,11 +51,14 @@ type OutgoingResponse struct {
body []byte
}
+type RequestId uint32 // Identifies an in-flight request; prefixed to REQUEST message data as REQUEST_ID_BUFF_SIZE little-endian bytes.
+
// Tunnel to Client
type ClientTunnel struct {
protocol.Tunnel
- incomingChannel chan IncomingRequest
- outgoingChannel chan protocol.TunnelMessage
+ incomingChannel chan IncomingRequest
+ outgoingChannel chan protocol.TunnelMessage
+ inflightRequests *sync.Map // RequestId -> IncomingRequest for requests forwarded to this client
}
func (ct *ClientTunnel) drainChannels() {
@@ -119,6 +120,16 @@ func (ct *ClientTunnel) close(graceful bool) {
)
}
+// Generate unique request id for incoming request for client
+func (ct *ClientTunnel) GenerateUniqueRequestID() RequestId {
+ var generatedReqId RequestId
+
+ for _, exists := ct.inflightRequests.Load(generatedReqId); exists || generatedReqId == 0; {
+ generatedReqId = RequestId(GenerateRandomUint32())
+ }
+ return generatedReqId
+}
+
// Serves simple stats for mmar server behind Basic Authentication
func (ms *MmarServer) handleServerStats(w http.ResponseWriter, r *http.Request) {
// Check Basic Authentication
@@ -193,19 +204,33 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Create response channel to receive response for tunneled request
respChannel := make(chan OutgoingResponse)
- // Tunnel the request
- clientTunnel.incomingChannel <- IncomingRequest{
+ // Add request to client's inflight requests
+ reqId := clientTunnel.GenerateUniqueRequestID()
+ incomingReq := IncomingRequest{
responseChannel: respChannel,
responseWriter: w,
request: r,
cancel: cancel,
- serializedReq: serializedRequest,
ctx: ctx,
}
+ clientTunnel.inflightRequests.Store(reqId, incomingReq)
+
+ // Construct Request message data
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ binary.LittleEndian.PutUint32(reqIdBuff, uint32(reqId))
+ reqMsgData := append(reqIdBuff, serializedRequest...)
+
+ // Tunnel the request to mmar client
+ reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: reqMsgData}
+ if err := clientTunnel.SendMessage(reqMessage); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err))
+ cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR)
+ }
select {
case <-ctx.Done(): // Request is canceled or Tunnel is closed if context is canceled
handleCancel(context.Cause(ctx), w)
+ clientTunnel.inflightRequests.Delete(reqId)
return
case resp, _ := <-respChannel: // Await response for tunneled request
// Add header to close the connection
@@ -220,20 +245,15 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
-func (ms *MmarServer) GenerateUniqueId() string {
- reservedIDs := []string{"", "admin", "stats"}
+// GenerateUniqueSubdomain returns a random subdomain that is neither reserved
+// ("", "admin", "stats") nor already registered in ms.clients. The caller must
+// hold ms.mu (as newClientTunnel does), since ms.clients is read here.
+func (ms *MmarServer) GenerateUniqueSubdomain() string {
+	reservedSubdomains := []string{"", "admin", "stats"}
- generatedId := ""
- for _, exists := ms.clients[generatedId]; exists || slices.Contains(reservedIDs, generatedId); {
- var randSeed *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
- b := make([]byte, constants.ID_LENGTH)
- for i := range b {
- b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))]
- }
- generatedId = string(b)
+	var generatedSubdomain string
+	// Draw until a free, non-reserved subdomain is found. The clients map must be
+	// consulted on every draw, not only once before the first one.
+	for {
+		generatedSubdomain = GenerateRandomID()
+		_, taken := ms.clients[generatedSubdomain]
+		if !taken && !slices.Contains(reservedSubdomains, generatedSubdomain) {
+			break
+		}
}
- return generatedId
+	return generatedSubdomain
}
func (ms *MmarServer) TunnelLimitedIP(ip string) bool {
@@ -247,31 +267,40 @@ func (ms *MmarServer) TunnelLimitedIP(ip string) bool {
return len(tunnels) >= constants.MAX_TUNNELS_PER_IP
}
-func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
+func (ms *MmarServer) newClientTunnel(tunnel protocol.Tunnel, subdomain string) (*ClientTunnel, error) {
// Acquire lock to create new client tunnel data
ms.mu.Lock()
- // Generate unique ID for client
- uniqueId := ms.GenerateUniqueId()
- tunnel := protocol.Tunnel{
- Id: uniqueId,
- Conn: conn,
- CreatedOn: time.Now(),
+ var uniqueSubdomain string
+ var msgType uint8
+ if subdomain != "" {
+ uniqueSubdomain = subdomain
+ msgType = protocol.TUNNEL_RECLAIMED
+ } else {
+ // Generate unique subdomain for client if not passed in
+ uniqueSubdomain = ms.GenerateUniqueSubdomain()
+ msgType = protocol.TUNNEL_CREATED
}
+ tunnel.Id = uniqueSubdomain
+
// Create channels to tunnel requests to and recieve responses from
incomingChannel := make(chan IncomingRequest)
outgoingChannel := make(chan protocol.TunnelMessage)
+ // Initialize inflight requests map for client tunnel
+ var inflightRequests sync.Map
+
// Create client tunnel
clientTunnel := ClientTunnel{
tunnel,
incomingChannel,
outgoingChannel,
+ &inflightRequests,
}
// Check if IP reached max tunnel limit
- clientIP := utils.ExtractIP(conn.RemoteAddr().String())
+ clientIP := utils.ExtractIP(tunnel.Conn.RemoteAddr().String())
limitedIP := ms.TunnelLimitedIP(clientIP)
// If so, send limit message to client and close client tunnel
if limitedIP {
@@ -286,18 +315,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
}
// Add client tunnel to clients
- ms.clients[uniqueId] = clientTunnel
+ ms.clients[uniqueSubdomain] = clientTunnel
// Associate tunnel with client IP
- ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueId)
+ ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueSubdomain)
// Release lock once created
ms.mu.Unlock()
- // Send unique ID to client
- connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(uniqueId)}
+ // Send unique subdomain to client
+ connMessage := protocol.TunnelMessage{MsgType: msgType, MsgData: []byte(uniqueSubdomain)}
if err := clientTunnel.SendMessage(connMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err))
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique subdomain msg to client: %v", err))
return nil, err
}
@@ -306,32 +335,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
func (ms *MmarServer) handleTcpConnection(conn net.Conn) {
- clientTunnel, err := ms.newClientTunnel(conn)
-
- if err != nil {
- if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
- // Close the connection when client max tunnels limit reached
- conn.Close()
- return
- }
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err))
- return
+ tunnel := protocol.Tunnel{
+ Conn: conn,
+ CreatedOn: time.Now(),
+ Reader: bufio.NewReader(conn),
}
- logger.Log(
- constants.DEFAULT_COLOR,
- fmt.Sprintf(
- "[%s] Tunnel created: %s",
- clientTunnel.Tunnel.Id,
- conn.RemoteAddr().String(),
- ),
- )
-
// Process Tunnel Messages coming from mmar client
- go ms.processTunnelMessages(clientTunnel)
+ go ms.processTunnelMessages(tunnel)
+}
- // Start goroutine to process tunneled requests
- go ms.processTunneledRequestsForClient(clientTunnel)
+func (ms *MmarServer) closeTunnel(t *protocol.Tunnel) {
+ t.Conn.Close()
}
func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) {
@@ -351,168 +366,216 @@ func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) {
ct.close(true)
}
-func (ms *MmarServer) processTunneledRequestsForClient(ct *ClientTunnel) {
- for {
- // Read requests coming in tunnel channel
- incomingReq, ok := <-ct.incomingChannel
- if !ok {
- // Channel closed, client disconencted, shutdown goroutine
- return
- }
+func (ms *MmarServer) closeClientTunnelOrConn(ct *ClientTunnel, t protocol.Tunnel) {
- // Forward the request to mmar client
- reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: incomingReq.serializedReq}
- if err := ct.SendMessage(reqMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err))
- incomingReq.cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR)
- continue
- }
+ // If client has not reserved subdomain, just close the tcp connection
+ if !ct.ReservedSubdomain() {
+ ms.closeTunnel(&t)
+ return
+ }
- // Wait for response for this request to come back from outgoing channel
- respTunnelMsg, ok := <-ct.outgoingChannel
- if !ok {
- // Channel closed, client disconencted, shutdown goroutine
- return
- }
+ ms.closeClientTunnel(ct)
+}
- // Read response for forwarded request
- respReader := bufio.NewReader(bytes.NewReader(respTunnelMsg.MsgData))
- resp, respErr := http.ReadResponse(respReader, incomingReq.request)
+func (ms *MmarServer) handleResponseMessages(ct *ClientTunnel, tunnelMsg protocol.TunnelMessage) {
+ respReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData))
- if respErr != nil {
- if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) {
- incomingReq.cancel(CLIENT_DISCONNECTED_ERR)
- ms.closeClientTunnel(ct)
- return
- }
- failedReq := fmt.Sprintf("%s - %s%s", incomingReq.request.Method, html.EscapeString(incomingReq.request.URL.Path), incomingReq.request.URL.RawQuery)
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq))
- incomingReq.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR)
- continue
- }
+ // Extract RequestId
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ _, err := io.ReadFull(respReader, reqIdBuff)
+ if err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] - Failed to parse RequestId for response: %v\n", ct.Tunnel.Id, err))
+ return
+ }
- respBody, respBodyErr := io.ReadAll(resp.Body)
- if respBodyErr != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr))
- incomingReq.cancel(READ_RESP_BODY_ERR)
- continue
- }
+ // Get Inflight Request and remove it from inflight requests
+ reqId := RequestId(binary.LittleEndian.Uint32(reqIdBuff))
+ inflight, loaded := ct.inflightRequests.LoadAndDelete(reqId)
+ if !loaded {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to identify inflight request: %v", ct.Tunnel.Id, reqId))
+ return
+ }
- // Set headers for response
- for hKey, hVal := range resp.Header {
- incomingReq.responseWriter.Header().Set(hKey, hVal[0])
- // Add remaining values for header if more than than one exists
- for i := 1; i < len(hVal); i++ {
- incomingReq.responseWriter.Header().Add(hKey, hVal[i])
- }
+ inflightRequest, ok := inflight.(IncomingRequest)
+ if !ok {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to parse inflight request: %v", ct.Tunnel.Id, reqId))
+ return
+ }
+
+ // Read response for forwarded request
+ resp, respErr := http.ReadResponse(respReader, inflightRequest.request)
+
+ if respErr != nil {
+ if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) {
+ inflightRequest.cancel(CLIENT_DISCONNECTED_ERR)
+ ms.closeClientTunnel(ct)
+ return
}
+ failedReq := fmt.Sprintf("%s - %s%s", inflightRequest.request.Method, html.EscapeString(inflightRequest.request.URL.Path), inflightRequest.request.URL.RawQuery)
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq))
+ inflightRequest.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR)
+ return
+ }
- // Close response body
- resp.Body.Close()
+ respBody, respBodyErr := io.ReadAll(resp.Body)
+ if respBodyErr != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr))
+ inflightRequest.cancel(READ_RESP_BODY_ERR)
+ return
+ }
- select {
- case <-incomingReq.ctx.Done():
- // Request is canceled, on to the next request
- continue
- case incomingReq.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}:
- // Send response data back
+ defer resp.Body.Close()
+
+ // Set headers for response
+ for hKey, hVal := range resp.Header {
+ inflightRequest.responseWriter.Header().Set(hKey, hVal[0])
+ // Add remaining values for header if more than than one exists
+ for i := 1; i < len(hVal); i++ {
+ inflightRequest.responseWriter.Header().Add(hKey, hVal[i])
}
}
+
+ select {
+ case <-inflightRequest.ctx.Done():
+ // Request is canceled, do nothing
+ return
+ case inflightRequest.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}:
+ // Send response data back
+ }
}
-func (ms *MmarServer) processTunnelMessages(ct *ClientTunnel) {
+func (ms *MmarServer) processTunnelMessages(t protocol.Tunnel) {
+ var ct *ClientTunnel
for {
- tunnelMsg, err := ct.ReceiveMessage()
+ // Send heartbeat if nothing has been read for a while
+ receiveMessageTimeout := time.AfterFunc(
+ constants.HEARTBEAT_FROM_SERVER_TIMEOUT*time.Second,
+ func() {
+ heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_SERVER}
+ if err := t.SendMessage(heartbeatMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send heartbeat: %v", err))
+ ms.closeClientTunnelOrConn(ct, t)
+ return
+ }
+ // Set a read timeout, if no response to heartbeat is received within that period,
+ // that means the client has disconnected
+ readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second))
+ t.Conn.SetReadDeadline(readDeadline)
+ },
+ )
+
+ tunnelMsg, err := t.ReceiveMessage()
+ // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout
+ // as we do not need to send heartbeat or check connection health in this iteration
+ receiveMessageTimeout.Stop()
+ t.Conn.SetReadDeadline(time.Time{})
+
if err != nil {
logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Receive Message from client tunnel errored: %v", err))
if utils.NetworkError(err) {
// If error with connection, stop processing messages
- ms.closeClientTunnel(ct)
+ ms.closeClientTunnelOrConn(ct, t)
return
}
continue
}
switch tunnelMsg.MsgType {
- case protocol.RESPONSE:
- ct.outgoingChannel <- tunnelMsg
- case protocol.LOCALHOST_NOT_RUNNING:
- // Create a response for Tunnel connected but localhost not running
- errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING)
- resp := http.Response{
- Status: "200 OK",
- StatusCode: http.StatusOK,
- Body: io.NopCloser(bytes.NewBufferString(errState)),
- }
+ case protocol.CREATE_TUNNEL:
+ // mmar client requesting new tunnel
+ ct, err = ms.newClientTunnel(t, "")
- // Writing response to buffer to tunnel it back
- var responseBuff bytes.Buffer
- resp.Write(&responseBuff)
- notRunningMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
- ct.outgoingChannel <- notRunningMsg
- case protocol.DEST_REQUEST_TIMEDOUT:
- // Create a response for Tunnel connected but localhost took too long to respond
- errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT)
- resp := http.Response{
- Status: "200 OK",
- StatusCode: http.StatusOK,
- Body: io.NopCloser(bytes.NewBufferString(errState)),
+ if err != nil {
+ if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
+ // Close the connection when client max tunnels limit reached
+ t.Conn.Close()
+ return
+ }
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err))
+ return
}
- // Writing response to buffer to tunnel it back
- var responseBuff bytes.Buffer
- resp.Write(&responseBuff)
- destTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
- ct.outgoingChannel <- destTimedoutMsg
- case protocol.CLIENT_DISCONNECT:
- ms.closeClientTunnel(ct)
- return
- case protocol.CLIENT_RECLAIM_SUBDOMAIN:
- newAndExistingIDs := strings.Split(string(tunnelMsg.MsgData), ":")
- newId := newAndExistingIDs[0]
- existingId := newAndExistingIDs[1]
+ logger.Log(
+ constants.DEFAULT_COLOR,
+ fmt.Sprintf(
+ "[%s] Tunnel created: %s",
+ ct.Tunnel.Id,
+ t.Conn.RemoteAddr().String(),
+ ),
+ )
+ case protocol.RECLAIM_TUNNEL:
+ // mmar client reclaiming a previously created tunnel
+ existingId := string(tunnelMsg.MsgData)
// Check if the subdomain has already been taken
_, ok := ms.clients[existingId]
if ok {
// if so, close the tunnel, so the user can create a new one
- ms.closeClientTunnel(ct)
+ ms.closeClientTunnelOrConn(ct, t)
return
}
- ct.Tunnel.Id = existingId
-
- // Add existing client tunnel to clients
- ms.clients[existingId] = *ct
-
- // Remove newId tunnel from clients
- delete(ms.clients, newId)
-
- // Update the tunnels for the IP
- clientIP := utils.ExtractIP(ct.Conn.RemoteAddr().String())
- newIdIndex := slices.Index(ms.tunnelsPerIP[clientIP], newId)
- if newIdIndex == -1 {
- ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], existingId)
- } else {
- ms.tunnelsPerIP[clientIP][newIdIndex] = existingId
- }
-
- connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(existingId)}
- if err := ct.SendMessage(connMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err))
- ms.closeClientTunnel(ct)
+ ct, err = ms.newClientTunnel(t, existingId)
+ if err != nil {
+ if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
+ // Close the connection when client max tunnels limit reached
+ t.Conn.Close()
+ return
+ }
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to reclaim ClientTunnel: %v", err))
return
}
logger.Log(
constants.DEFAULT_COLOR,
fmt.Sprintf(
- "[%s] Tunnel reclaimed: %s -> %s",
- newId,
- ct.Conn.RemoteAddr().String(),
+ "[%s] Tunnel reclaimed: %s",
existingId,
+ ct.Conn.RemoteAddr().String(),
),
)
+ case protocol.RESPONSE:
+ go ms.handleResponseMessages(ct, tunnelMsg)
+ case protocol.LOCALHOST_NOT_RUNNING:
+ // Create a response for Tunnel connected but localhost not running
+ errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING)
+ responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState)
+ notRunningMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, notRunningMsg)
+ case protocol.DEST_REQUEST_TIMEDOUT:
+ // Create a response for Tunnel connected but localhost took too long to respond
+ errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT)
+ responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState)
+ destTimedoutMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, destTimedoutMsg)
+ case protocol.CLIENT_DISCONNECT:
+ ms.closeClientTunnelOrConn(ct, t)
+ return
+ case protocol.HEARTBEAT_FROM_CLIENT:
+ heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK}
+ if err := t.SendMessage(heartbeatAckMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to heartbeat ack to client: %v", err))
+ ms.closeClientTunnelOrConn(ct, t)
+ return
+ }
+ case protocol.HEARTBEAT_ACK:
+ // Got a heartbeat ack, that means the connection is healthy,
+ // we do not need to perform any action
+ case protocol.INVALID_RESP_FROM_DEST:
+ // Create a response for receiving invalid response from destination server
+ errState := protocol.TunnelErrState(protocol.INVALID_RESP_FROM_DEST)
+ responseBuff := createSerializedServerResp("500 Internal Server Error", http.StatusInternalServerError, errState)
+ invalidRespFromDestMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, invalidRespFromDestMsg)
}
}
}
diff --git a/internal/server/utils.go b/internal/server/utils.go
index 7e0fa62..41d9b68 100644
--- a/internal/server/utils.go
+++ b/internal/server/utils.go
@@ -3,9 +3,12 @@ package server
import (
"bytes"
"context"
+ cryptoRand "crypto/rand"
+ "encoding/binary"
"errors"
"fmt"
"io"
+ mathRand "math/rand"
"net/http"
"strconv"
"time"
@@ -21,7 +24,7 @@ var MAX_REQ_BODY_SIZE_ERR error = errors.New(constants.MAX_REQ_BODY_SIZE_ERR_TEX
var FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR_TEXT)
var FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR_TEXT)
-func responseWith(respText string, w http.ResponseWriter, statusCode int) {
+func respondWith(respText string, w http.ResponseWriter, statusCode int) {
w.Header().Set("Content-Length", strconv.Itoa(len(respText)))
w.Header().Set("Connection", "close")
w.WriteHeader(statusCode)
@@ -34,15 +37,15 @@ func handleCancel(cause error, w http.ResponseWriter) {
// Cancelled, do nothing
return
case READ_BODY_CHUNK_TIMEOUT_ERR:
- responseWith(cause.Error(), w, http.StatusRequestTimeout)
+ respondWith(cause.Error(), w, http.StatusRequestTimeout)
case READ_BODY_CHUNK_ERR, CLIENT_DISCONNECTED_ERR:
- responseWith(cause.Error(), w, http.StatusBadRequest)
+ respondWith(cause.Error(), w, http.StatusBadRequest)
case READ_RESP_BODY_ERR:
- responseWith(cause.Error(), w, http.StatusInternalServerError)
+ respondWith(cause.Error(), w, http.StatusInternalServerError)
case MAX_REQ_BODY_SIZE_ERR:
- responseWith(cause.Error(), w, http.StatusRequestEntityTooLarge)
+ respondWith(cause.Error(), w, http.StatusRequestEntityTooLarge)
case FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR, FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR:
- responseWith(cause.Error(), w, http.StatusServiceUnavailable)
+ respondWith(cause.Error(), w, http.StatusServiceUnavailable)
}
}
@@ -65,7 +68,7 @@ func serializeRequest(ctx context.Context, r *http.Request, cancel context.Cance
fmt.Sprintf(
"%v %v %v\nHost: %v\n",
r.Method,
- r.URL.Path,
+ r.RequestURI,
r.Proto,
r.Host,
),
@@ -119,3 +122,35 @@ func serializeRequest(ctx context.Context, r *http.Request, cancel context.Cance
// Send serialized request through channel
serializedRequestChannel <- requestBuff.Bytes()
}
+
+// Create HTTP response sent from mmar server to the end-user client
+func createSerializedServerResp(status string, statusCode int, body string) bytes.Buffer {
+ resp := http.Response{
+ Status: status,
+ StatusCode: statusCode,
+ Body: io.NopCloser(bytes.NewBufferString(body)),
+ }
+
+ // Writing response to buffer to tunnel it back
+ var responseBuff bytes.Buffer
+ resp.Write(&responseBuff)
+
+ return responseBuff
+}
+
+// Generate a random ID from ID_CHARSET of length ID_LENGTH
+func GenerateRandomID() string {
+ var randSeed *mathRand.Rand = mathRand.New(mathRand.NewSource(time.Now().UnixNano()))
+ b := make([]byte, constants.ID_LENGTH)
+ for i := range b {
+ b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))]
+ }
+ return string(b)
+}
+
+// Generate a random 32-bit unsigned integer
+func GenerateRandomUint32() uint32 {
+ var randomUint32 uint32
+ binary.Read(cryptoRand.Reader, binary.BigEndian, &randomUint32)
+ return randomUint32
+}
diff --git a/internal/utils/main.go b/internal/utils/main.go
index a35d885..b9aa575 100644
--- a/internal/utils/main.go
+++ b/internal/utils/main.go
@@ -112,5 +112,14 @@ func NetworkError(err error) bool {
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrUnexpectedEOF) ||
errors.Is(err, net.ErrClosed) ||
- errors.Is(err, syscall.ECONNRESET)
+ errors.Is(err, syscall.ECONNRESET) ||
+ errors.Is(err, os.ErrDeadlineExceeded)
+}
+
+func EnvVarOrDefault(envVar string, defaultVal string) string {
+ envValue, ok := os.LookupEnv(envVar)
+ if !ok {
+ return defaultVal
+ }
+ return envValue
}
diff --git a/simulations/devserver/main.go b/simulations/devserver/main.go
index 96741cd..973d42a 100644
--- a/simulations/devserver/main.go
+++ b/simulations/devserver/main.go
@@ -25,11 +25,19 @@ type DevServer struct {
*httptest.Server
}
-func NewDevServer() *DevServer {
+func NewDevServer(proto string, addr string) *DevServer {
mux := setupMux()
+ var httpServer *httptest.Server
+ switch proto {
+ case "https":
+ httpServer = httptest.NewTLSServer(mux)
+ case "http":
+ httpServer = httptest.NewServer(mux)
+ }
+
return &DevServer{
- httptest.NewServer(mux),
+ httpServer,
}
}
@@ -55,12 +63,14 @@ func setupMux() *http.ServeMux {
}
func handleGet(w http.ResponseWriter, r *http.Request) {
- // Include echo of request headers in response to confirm they were received
+ // Include echo of request headers and query params in response to
+ // confirm they were received
respBody, err := json.Marshal(map[string]interface{}{
"success": true,
"data": "some data",
"echo": map[string]interface{}{
- "reqHeaders": r.Header,
+ "reqHeaders": r.Header,
+ "reqQueryParams": r.URL.Query(),
},
})
@@ -168,19 +178,27 @@ func handleRedirect(w http.ResponseWriter, r *http.Request) {
// Request handler that returns an invalid HTTP response
func handleBadResp(w http.ResponseWriter, r *http.Request) {
- // Return a response with Content-Length headers that do not match the actual data
- respBody, err := json.Marshal(map[string]interface{}{
- "data": "some data",
- })
+ // Get the underlying connection object
+ // Assert that w supports Hijacking
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
+ return
+ }
+ // Hijack the connection
+ conn, buf, err := hijacker.Hijack()
if err != nil {
- log.Fatalf("Failed to marshal response for GET: %v", err)
+ http.Error(w, "Hijacking failed", http.StatusInternalServerError)
+ return
}
+ defer conn.Close()
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Length", "123") // Content length much larger than actual content
- w.WriteHeader(http.StatusOK)
- w.Write(respBody)
+ // Send back an invalid HTTP response
+ buf.WriteString("some random string\r\n" +
+ "\r\n" +
+ "that is not a valid http resp")
+ buf.Flush()
}
// Request handler that takes a long time before returning response
diff --git a/simulations/simulation_test.go b/simulations/simulation_test.go
index 924e13d..174efec 100644
--- a/simulations/simulation_test.go
+++ b/simulations/simulation_test.go
@@ -43,7 +43,15 @@ func StartMmarServer(ctx context.Context) {
}
}
-func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort string) {
+func StartMmarClient(
+ ctx context.Context,
+ urlCh chan string,
+ localDevServerPort string,
+ localDevServerHost string,
+ localDevServerProto string,
+ customDns string,
+ customCert string,
+) {
cmd := exec.CommandContext(
ctx,
"./mmar",
@@ -54,6 +62,26 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
localDevServerPort,
)
+ if localDevServerHost != "" {
+ cmd.Args = append(cmd.Args, "--local-host", localDevServerHost)
+ }
+
+ if localDevServerProto != "" {
+ cmd.Args = append(cmd.Args, "--local-proto", localDevServerProto)
+ }
+
+ if customDns != "" {
+ cmd.Args = append(cmd.Args, "--custom-dns", customDns)
+ }
+
+ if customCert != "" {
+ cmd.Args = append(cmd.Args, "--custom-cert", customCert)
+ }
+
+ cmd.Args = append(cmd.Args, "")
+
+ cmd.Stdout = os.Stdout
+
// Pipe Stderr To capture logs for extracting the tunnel url
pipe, _ := cmd.StderrPipe()
@@ -77,7 +105,6 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
tunnelUrl := extractTunnelURL(line)
if tunnelUrl != "" {
urlCh <- tunnelUrl
- break
}
line, readErr = stdoutReader.ReadString('\n')
}
@@ -91,9 +118,9 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
}
}
-func StartLocalDevServer() *devserver.DevServer {
- ds := devserver.NewDevServer()
- log.Printf("Started local dev server on: http://localhost:%v", ds.Port())
+func StartLocalDevServer(proto string, addr string) *devserver.DevServer {
+ ds := devserver.NewDevServer(proto, addr)
+ log.Printf("Started local dev server on: %v://%v:%v", proto, addr, ds.Port())
return ds
}
@@ -107,14 +134,21 @@ func verifyGetRequestSuccess(t *testing.T, client *http.Client, tunnelUrl string
// Adding custom header to confirm that they are propogated when going through mmar
req.Header.Set("Simulation-Test", "verify-get-request-success")
+ // Adding query params to confirm that they get propogated when going through mmar
+ q := req.URL.Query()
+ q.Add("first", "query param")
+ q.Add("second", "param & last")
+ req.URL.RawQuery = q.Encode()
+
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-get-request-success"},
}
@@ -122,7 +156,8 @@ func verifyGetRequestSuccess(t *testing.T, client *http.Client, tunnelUrl string
"success": true,
"data": "some data",
"echo": map[string]interface{}{
- "reqHeaders": expectedReqHeaders,
+ "reqHeaders": expectedReqHeaders,
+ "reqQueryParams": q,
},
}
marshaledBody, _ := json.Marshal(expectedBody)
@@ -152,12 +187,13 @@ func verifyGetRequestFail(t *testing.T, client *http.Client, tunnelUrl string, w
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-get-request-fail"},
}
@@ -203,12 +239,13 @@ func verifyPostRequestSuccess(t *testing.T, client *http.Client, tunnelUrl strin
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-post-request-success"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -258,12 +295,13 @@ func verifyPostRequestFail(t *testing.T, client *http.Client, tunnelUrl string,
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-post-request-fail"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -312,12 +350,13 @@ func verifyRedirectsHandled(t *testing.T, client *http.Client, tunnelUrl string,
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-redirect-request"},
"Referer": {tunnelUrl + "/redirect"}, // Include referer header since it redirects
}
@@ -326,7 +365,8 @@ func verifyRedirectsHandled(t *testing.T, client *http.Client, tunnelUrl string,
"success": true,
"data": "some data",
"echo": map[string]interface{}{
- "reqHeaders": expectedReqHeaders,
+ "reqHeaders": expectedReqHeaders,
+ "reqQueryParams": map[string][]string{},
},
}
marshaledBody, _ := json.Marshal(expectedBody)
@@ -370,6 +410,7 @@ func verifyInvalidMethodRequestHandled(t *testing.T, client *http.Client, tunnel
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-invalid-method-request"},
}
@@ -377,7 +418,8 @@ func verifyInvalidMethodRequestHandled(t *testing.T, client *http.Client, tunnel
"success": true,
"data": "some data",
"echo": map[string]interface{}{
- "reqHeaders": expectedReqHeaders,
+ "reqHeaders": expectedReqHeaders,
+ "reqQueryParams": map[string][]string{},
},
}
marshaledBody, _ := json.Marshal(expectedBody)
@@ -441,7 +483,6 @@ func verifyInvalidHttpVersionRequestHandled(t *testing.T, tunnelUrl string, wg *
if respErr != nil {
t.Errorf("%v: Failed to get response %v", "verifyInvalidHttpVersionRequestHandled", respErr)
}
-
if resp.StatusCode != http.StatusBadRequest {
t.Errorf(
"%v: resp.StatusCode = %v; want %v",
@@ -537,7 +578,6 @@ func verifyContentLengthWithNoBodyRequestHandled(t *testing.T, tunnelUrl string,
if respErr != nil {
t.Errorf("%v: Failed to get response %v", "verifyContentLengthWithNoBodyRequestHandled", respErr)
}
-
expectedBody := constants.READ_BODY_CHUNK_TIMEOUT_ERR_TEXT
expectedResp := expectedResponse{
@@ -571,12 +611,13 @@ func verifyRequestWithLargeBody(t *testing.T, client *http.Client, tunnelUrl str
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-large-post-request-success"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -625,7 +666,11 @@ func verifyRequestWithVeryLargeBody(t *testing.T, client *http.Client, tunnelUrl
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ // Check if connection was closed in the middle of writing, that's also valid behavior
+ if !strings.Contains(respErr.Error(), "write: connection reset by peer") {
+ t.Errorf("Failed to get response: %v", respErr)
+ }
+ return
}
expectedBody := constants.MAX_REQ_BODY_SIZE_ERR_TEXT
@@ -652,7 +697,7 @@ func verifyDevServerReturningInvalidRespHandled(t *testing.T, client *http.Clien
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.READ_RESP_BODY_ERR_TEXT
@@ -679,7 +724,7 @@ func verifyDevServerLongRunningReqHandledGradefully(t *testing.T, client *http.C
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT
@@ -706,7 +751,7 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.LOCALHOST_NOT_RUNNING_ERR_TEXT
@@ -726,19 +771,45 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu
func TestSimulation(t *testing.T) {
simulationCtx, simulationCancel := context.WithCancel(context.Background())
- localDevServer := StartLocalDevServer()
+ // Start a local dev server with http
+ localDevServer := StartLocalDevServer("http", "localhost")
defer localDevServer.Close()
+ // Start a local dev server with https
+ localDevTLSServer := StartLocalDevServer("https", "example.com")
+ defer localDevTLSServer.Close()
+
+ // Write cert to file so we are able to pass it into mmar client
+ certErr := os.WriteFile("./temp-cert", localDevTLSServer.Certificate().Raw, 0644) // 0644 is file permissions
+ if certErr != nil {
+ log.Fatal(certErr)
+ }
+
go dnsserver.StartDnsServer()
go StartMmarServer(simulationCtx)
wait := time.NewTimer(2 * time.Second)
<-wait.C
- clientUrlCh := make(chan string)
- go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port())
- // Wait for tunnel url
- tunnelUrl := <-clientUrlCh
+ // Start a basic mmar client
+ basicClientUrlCh := make(chan string)
+ go StartMmarClient(simulationCtx, basicClientUrlCh, localDevServer.Port(), "", "", "", "")
+
+ // Start another basic mmar client
+ basicClientUrlCh2 := make(chan string)
+ go StartMmarClient(simulationCtx, basicClientUrlCh2, localDevServer.Port(), "", "", "", "")
+
+ // Wait for all tunnel urls
+ mmarClientsCount := 2
+ tunnelUrls := []string{}
+ for range mmarClientsCount {
+ select {
+ case tunnelUrl := <-basicClientUrlCh:
+ tunnelUrls = append(tunnelUrls, tunnelUrl)
+ case tunnelUrl := <-basicClientUrlCh2:
+ tunnelUrls = append(tunnelUrls, tunnelUrl)
+ }
+ }
// Initialize http client
client := httpClient()
@@ -774,18 +845,27 @@ func TestSimulation(t *testing.T) {
verifyContentLengthWithNoBodyRequestHandled,
}
- for _, simTest := range simulationTests {
- wg.Add(1)
- go simTest(t, client, tunnelUrl, &wg)
- }
+ // Loop through all tunnel urls and run simulation tests
+ for _, tunnelUrl := range tunnelUrls {
+
+ for _, simTest := range simulationTests {
+ wg.Add(1)
+ go simTest(t, client, tunnelUrl, &wg)
+ }
- for _, manualClientSimTest := range manualClientSimulationTests {
- wg.Add(1)
- go manualClientSimTest(t, tunnelUrl, &wg)
+ for _, manualClientSimTest := range manualClientSimulationTests {
+ wg.Add(1)
+ go manualClientSimTest(t, tunnelUrl, &wg)
+ }
}
wg.Wait()
+ // Delete cert file
+ if rmErr := os.Remove("./temp-cert"); rmErr != nil {
+ log.Fatal(rmErr)
+ }
+
// Stop simulation tests
simulationCancel()
diff --git a/simulations/simulation_utils.go b/simulations/simulation_utils.go
index d6ae1bf..0a13fab 100644
--- a/simulations/simulation_utils.go
+++ b/simulations/simulation_utils.go
@@ -95,7 +95,8 @@ func httpClient() *http.Client {
dialer := initCustomDialer()
tp := &http.Transport{
- DialContext: dialer.DialContext,
+ DialContext: dialer.DialContext,
+ DisableKeepAlives: true,
}
client := &http.Client{Transport: tp}
return client