diff --git a/.github/workflows/commit_lint.yml b/.github/workflows/commit_lint.yml
new file mode 100644
index 0000000..2df66ba
--- /dev/null
+++ b/.github/workflows/commit_lint.yml
@@ -0,0 +1,48 @@
+name: Lint Commit Messages
+
+on:
+  workflow_dispatch:
+  pull_request:
+
+permissions:
+  contents: read
+
+jobs:
+  commitlint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup node
+        uses: actions/setup-node@v4
+        with:
+          node-version: lts/*
+
+      - name: Install commitlint
+        run: npm install -D @commitlint/cli @commitlint/config-conventional
+
+      - name: Print versions
+        run: |
+          git --version
+          node --version
+          npm --version
+          npx commitlint --version
+
+      - name: Create default conventional commitlint config
+        run: |
+          echo "module.exports = {extends: ['@commitlint/config-conventional']};" > commitlint.config.js
+
+      # github.event.pull_request only exists for pull_request events; on a
+      # manual (workflow_dispatch) run these SHAs would expand to empty strings.
+      - name: Validate PR commits with commitlint
+        if: github.event_name == 'pull_request'
+        env:
+          BASE_SHA: ${{ github.event.pull_request.base.sha }}
+          HEAD_SHA: ${{ github.event.pull_request.head.sha }}
+        run: npx commitlint --from "$BASE_SHA" --to "$HEAD_SHA" --verbose
+
+      - name: Validate latest commit with commitlint (manual runs)
+        if: github.event_name != 'pull_request'
+        run: npx commitlint --last --verbose
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 11f6005..aec6f5c 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -14,6 +14,7 @@ builds:
goarch:
- amd64
- arm64
+ - 386 # 32-bit x86; applies to every goos of this build (not only Windows)
archives:
- formats: [ 'tar.gz' ]
@@ -36,6 +37,7 @@ changelog:
exclude:
- "^docs:"
- "^test:"
+ - "^chore:"
release:
github:
diff --git a/README.md b/README.md
index ccacfe3..af21345 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,3 @@
-> WARNING: This is still in Alpha stage, not ready for production use yet.
-
@@ -8,21 +6,33 @@
# mmar
-mmar (pronounced "ma-mar") is a zero-dependancy, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL.
+mmar (pronounced "ma-mar") is a zero-dependency, self-hostable, cross-platform HTTP tunnel that exposes your localhost to the world on a public URL.
It allows you to quickly share what you are working on locally with others without the hassle of a full deployment, especially if it is not ready to be shipped.
-
+### Demo
+
+![mmar demo](docs/assets/img/mmar-demo.gif)
+
### Key Features
- Super simple to use
-- Utilize "mmar.dev" to tunnel for free on a generated subdomain
+- Provides "mmar.dev" to tunnel for free on a generated subdomain
- Expose multiple ports on different subdomains
- Live logs of requests coming into your localhost server
-- Zero dependancies
+- Zero dependencies
- Self-host your own mmar server to have full control
+### Limitations
+
+- Currently only the HTTP protocol is supported; other protocols, such as WebSockets, have not been tested and will likely not work
+- Requests through mmar are limited to 10 MB in size; this could be made configurable in the future
+- There is a limit of 5 mmar tunnels per IP address to prevent abuse; this could also be made configurable in the future
+
### Learn More
The development, implementation and technical details of mmar has all been documented in a [devlog series](https://ymusleh.com/tags/mmar.html). You can read more about it there.
@@ -31,7 +41,15 @@ _p.s. mmar means “corridor” or “pass-through” in Arabic._
## Installation
-### MacOS
+### Linux/MacOS
+
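+Install mmar (the script may prompt for your password, since it uses `sudo` to install the binary into `/usr/local/bin`):
+
+```sh
+curl -fsSL https://raw.githubusercontent.com/yusuf-musleh/mmar/refs/heads/master/install.sh | sh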
+```
+
+### MacOS (Homebrew)
Use [Homebrew](https://brew.sh/) to install `mmar` on MacOS:
@@ -50,13 +68,9 @@ brew upgrade yusuf-musleh/mmar-tap/mmar
The fastest way to create a tunnel what is running on your `localhost:8080` using [Docker](https://www.docker.com/) is by running this command:
```
-docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.2.2 client --local-port 8080
+docker run --rm --network host ghcr.io/yusuf-musleh/mmar:v0.2.6 client --local-port 8080
```
-### Linux
-
-See Docker or Manual installation instructions
-
### Windows
See Docker or Manual installation instructions
@@ -111,6 +125,19 @@ Commands:
Run `mmar -h` to get help for a specific command
```
+### Configuring through Environment Variables
+
+You can define the various mmar command flags in environment variables rather than passing them in with the command. Here are the available environment variables along with the corresponding flags:
+
+```
+MMAR__SERVER_HTTP_PORT -> mmar server --http-port
+MMAR__SERVER_TCP_PORT -> mmar server --tcp-port
+MMAR__LOCAL_PORT -> mmar client --local-port
+MMAR__TUNNEL_HTTP_PORT -> mmar client --tunnel-http-port
+MMAR__TUNNEL_TCP_PORT -> mmar client --tunnel-tcp-port
+MMAR__TUNNEL_HOST -> mmar client --tunnel-host
+MMAR__CUSTOM_DNS -> mmar client --custom-dns
+MMAR__CUSTOM_CERT -> mmar client --custom-cert
+```
+
## Self-Host
Since everything is open-source, you can easily self-host mmar on your own infrastructure under your own domain.
@@ -133,7 +160,7 @@ To deploy mmar on your own VPS using docker, you can do the following:
```yaml
services:
mmar-server:
- image: "ghcr.io/yusuf-musleh/mmar:v0.2.2" # <----- make sure to use the mmar's latest version
+ image: "ghcr.io/yusuf-musleh/mmar:v0.2.6" # <----- make sure to use mmar's latest version
restart: unless-stopped
command: server
environment:
@@ -144,6 +171,20 @@ To deploy mmar on your own VPS using docker, you can do the following:
- "6673:6673"
```
+ The `USERNAME_HASH` and `PASSWORD_HASH` env variables are the hashes of the credentials needed to access the stats page, which can be viewed at `stats.yourdomain.com`. The stats page returns JSON with basic information about the number of connected clients (i.e. open tunnels), along with a list of their subdomains and when they were created:
+
+ ```json
+ {
+ "connectedClients": [
+ {
+ "createdOn": "2025-03-01T08:01:46Z",
+ "id": "owrwf0"
+ }
+ ],
+ "connectedClientsCount": 1
+ }
+ ```
+
1. Next, we need to also add a reverse proxy, such as [Nginx](https://nginx.org/) or [Caddy](https://caddyserver.com/), so that requests and TCP connections to your domain are routed accordingly. Since the mmar client communicates with the server using TCP, you need to make sure that the reverse proxy supports routing on TCP, and not just HTTP.
I highly recommend [Caddy](https://caddyserver.com/) as it also handles obtaining SSL certificates for your wildcard subdomains automatically for you, in addition to having a Layer4 reverse proxy to route TCP connections. To get this functionality we need to include a few additional Caddy modules, the [layer4 module](github.com/mholt/caddy-l4) as well as the [caddy-dns](https://github.com/caddy-dns) module that matches your domain registrar, in my case I am using the [namecheap module](https://github.com/caddy-dns/namecheap) in order to automatically issue SSL certificates for wildcard subdomains.
@@ -215,7 +256,7 @@ To deploy mmar on your own VPS using docker, you can do the following:
- caddy_data:/data
- caddy_config:/config
mmar-server:
- image: "ghcr.io/yusuf-musleh/mmar:v0.2.2" # <----- make sure to use the mmar's latest version
+ image: "ghcr.io/yusuf-musleh/mmar:v0.2.6" # <----- make sure to use mmar's latest version
restart: unless-stopped
command: server
environment:
diff --git a/cmd/mmar/main.go b/cmd/mmar/main.go
index 87b9778..edac2bb 100644
--- a/cmd/mmar/main.go
+++ b/cmd/mmar/main.go
@@ -14,24 +14,46 @@ import (
func main() {
serverCmd := flag.NewFlagSet(constants.SERVER_CMD, flag.ExitOnError)
serverHttpPort := serverCmd.String(
- "http-port", constants.SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT_HELP,
+ "http-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_HTTP_PORT, constants.SERVER_HTTP_PORT),
+ constants.SERVER_HTTP_PORT_HELP,
)
serverTcpPort := serverCmd.String(
- "tcp-port", constants.SERVER_TCP_PORT, constants.SERVER_TCP_PORT_HELP,
+ "tcp-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_SERVER_TCP_PORT, constants.SERVER_TCP_PORT),
+ constants.SERVER_TCP_PORT_HELP,
)
clientCmd := flag.NewFlagSet(constants.CLIENT_CMD, flag.ExitOnError)
clientLocalPort := clientCmd.String(
- "local-port", constants.CLIENT_LOCAL_PORT, constants.CLIENT_LOCAL_PORT_HELP,
+ "local-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_LOCAL_PORT, constants.CLIENT_LOCAL_PORT),
+ constants.CLIENT_LOCAL_PORT_HELP,
)
clientTunnelHttpPort := clientCmd.String(
- "tunnel-http-port", constants.TUNNEL_HTTP_PORT, constants.CLIENT_HTTP_PORT_HELP,
+ "tunnel-http-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HTTP_PORT, constants.TUNNEL_HTTP_PORT),
+ constants.CLIENT_HTTP_PORT_HELP,
)
clientTunnelTcpPort := clientCmd.String(
- "tunnel-tcp-port", constants.SERVER_TCP_PORT, constants.CLIENT_TCP_PORT_HELP,
+ "tunnel-tcp-port",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_TCP_PORT, constants.SERVER_TCP_PORT),
+ constants.CLIENT_TCP_PORT_HELP,
)
clientTunnelHost := clientCmd.String(
- "tunnel-host", constants.TUNNEL_HOST, constants.TUNNEL_HOST_HELP,
+ "tunnel-host",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_TUNNEL_HOST, constants.TUNNEL_HOST),
+ constants.TUNNEL_HOST_HELP,
+ )
+ clientCustomDns := clientCmd.String(
+ "custom-dns",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_DNS, ""),
+ constants.CLIENT_CUSTOM_DNS_HELP,
+ )
+ clientCustomCert := clientCmd.String(
+ "custom-cert",
+ utils.EnvVarOrDefault(constants.MMAR_ENV_VAR_CUSTOM_CERT, ""),
+ constants.CLIENT_CUSTOM_CERT_HELP,
)
versionCmd := flag.NewFlagSet(constants.VERSION_CMD, flag.ExitOnError)
@@ -59,6 +81,8 @@ func main() {
TunnelHttpPort: *clientTunnelHttpPort,
TunnelTcpPort: *clientTunnelTcpPort,
TunnelHost: *clientTunnelHost,
+ CustomDns: *clientCustomDns,
+ CustomCert: *clientCustomCert,
}
client.Run(mmarClientConfig)
case constants.VERSION_CMD:
diff --git a/constants/main.go b/constants/main.go
index 8f96cea..51ea7c5 100644
--- a/constants/main.go
+++ b/constants/main.go
@@ -1,7 +1,7 @@
package constants
const (
- MMAR_VERSION = "0.2.2"
+ MMAR_VERSION = "0.2.6"
VERSION_CMD = "version"
SERVER_CMD = "server"
@@ -12,29 +12,44 @@ const (
TUNNEL_HOST = "mmar.dev"
TUNNEL_HTTP_PORT = "443"
+ MMAR_ENV_VAR_SERVER_HTTP_PORT = "MMAR__SERVER_HTTP_PORT"
+ MMAR_ENV_VAR_SERVER_TCP_PORT = "MMAR__SERVER_TCP_PORT"
+ MMAR_ENV_VAR_LOCAL_PORT = "MMAR__LOCAL_PORT"
+ MMAR_ENV_VAR_TUNNEL_HTTP_PORT = "MMAR__TUNNEL_HTTP_PORT"
+ MMAR_ENV_VAR_TUNNEL_TCP_PORT = "MMAR__TUNNEL_TCP_PORT"
+ MMAR_ENV_VAR_TUNNEL_HOST = "MMAR__TUNNEL_HOST"
+ MMAR_ENV_VAR_CUSTOM_DNS = "MMAR__CUSTOM_DNS"
+ MMAR_ENV_VAR_CUSTOM_CERT = "MMAR__CUSTOM_CERT"
+
SERVER_STATS_DEFAULT_USERNAME = "admin"
SERVER_STATS_DEFAULT_PASSWORD = "admin"
SERVER_HTTP_PORT_HELP = "Define port where mmar will bind to and run on server for HTTP requests."
SERVER_TCP_PORT_HELP = "Define port where mmar will bind to and run on server for TCP connections."
- CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar."
- CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel."
- CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel."
- TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to."
+ CLIENT_LOCAL_PORT_HELP = "Define the port where your local dev server is running to expose through mmar."
+ CLIENT_HTTP_PORT_HELP = "Define port of mmar HTTP server to make requests through the tunnel."
+ CLIENT_TCP_PORT_HELP = "Define port of mmar TCP server for client to connect to, creating a tunnel."
+ TUNNEL_HOST_HELP = "Define host domain of mmar server for client to connect to."
+ CLIENT_CUSTOM_DNS_HELP = "Define a custom DNS server that the mmar client should use when accessing your local dev server. (eg: 8.8.8.8:53, defaults to DNS in OS)"
+ CLIENT_CUSTOM_CERT_HELP = "Define path to file custom TLS certificate containing complete ASN.1 DER content (certificate, signature algorithm and signature). Currently used for testing, but may be used to allow mmar client to work with a dev server using custom TLS certificate setups. (eg: /path/to/cert)"
- TUNNEL_MESSAGE_PROTOCOL_VERSION = 1
+ TUNNEL_MESSAGE_PROTOCOL_VERSION = 4
TUNNEL_MESSAGE_DATA_DELIMITER = '\n'
ID_CHARSET = "abcdefghijklmnopqrstuvwxyz0123456789"
ID_LENGTH = 6
- MAX_TUNNELS_PER_IP = 5
- TUNNEL_RECONNECT_TIMEOUT = 3
- GRACEFUL_SHUTDOWN_TIMEOUT = 3
- TUNNEL_CREATE_TIMEOUT = 3
- REQ_BODY_READ_CHUNK_TIMEOUT = 3
- DEST_REQUEST_TIMEOUT = 30
- MAX_REQ_BODY_SIZE = 10000000 // 10mb
+ MAX_TUNNELS_PER_IP = 5
+ TUNNEL_RECONNECT_TIMEOUT = 3
+ GRACEFUL_SHUTDOWN_TIMEOUT = 3
+ TUNNEL_CREATE_TIMEOUT = 3
+ REQ_BODY_READ_CHUNK_TIMEOUT = 3
+ DEST_REQUEST_TIMEOUT = 30
+ HEARTBEAT_FROM_SERVER_TIMEOUT = 5
+ HEARTBEAT_FROM_CLIENT_TIMEOUT = 2
+ READ_DEADLINE = 3
+ MAX_REQ_BODY_SIZE = 10000000 // 10mb
+ REQUEST_ID_BUFF_SIZE = 4
CLIENT_DISCONNECT_ERR_TEXT = "Tunnel is closed, cannot connect to mmar client."
LOCALHOST_NOT_RUNNING_ERR_TEXT = "Tunneled successfully, but nothing is running on localhost."
diff --git a/docs/assets/img/mmar-demo.gif b/docs/assets/img/mmar-demo.gif
new file mode 100644
index 0000000..636a18b
Binary files /dev/null and b/docs/assets/img/mmar-demo.gif differ
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000..20efdd8
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,63 @@
+#!/bin/sh
+
+set -e
+
+REPO="yusuf-musleh/mmar"
+BINARY="mmar"
+
+echo "Installing $BINARY..."
+
+# Detect OS (must match the OS part of the release asset names)
+OS="$(uname -s)"
+case "$OS" in
+  Linux) OS_TITLE="Linux";;
+  Darwin) OS_TITLE="Darwin";;
+  *) echo "Unsupported OS: $OS"; exit 1;;
+esac
+
+# Detect ARCH (must match the architecture part of the release asset names).
+# 32-bit x86 Linux commonly reports i686 rather than i386.
+ARCH="$(uname -m)"
+case "$ARCH" in
+  x86_64) ARCH_ID="x86_64";;
+  i386|i486|i586|i686) ARCH_ID="i386";;
+  aarch64|arm64) ARCH_ID="arm64";;
+  *) echo "Unsupported architecture: $ARCH"; exit 1;;
+esac
+
+ASSET="${BINARY}_${OS_TITLE}_${ARCH_ID}.tar.gz"
+URL="https://github.com/$REPO/releases/latest/download/$ASSET"
+
+# Only use sudo when not already running as root (e.g. inside containers,
+# where sudo is often not installed).
+SUDO=""
+if [ "$(id -u)" -ne 0 ]; then
+  SUDO="sudo"
+fi
+
+# Temp dir, removed again when the script exits (on success or failure)
+TMP_DIR="$(mktemp -d)"
+trap 'rm -rf "$TMP_DIR"' EXIT
+cd "$TMP_DIR"
+
+echo "Downloading $ASSET..."
+# -f: fail on HTTP errors instead of saving the error page as the archive
+curl -fsSL "$URL" -o "$ASSET"
+
+echo "Extracting..."
+tar -xzf "$ASSET"
+
+# Install location
+INSTALL_DIR="/usr/local/bin"
+
+# Ensure the install directory exists
+if [ ! -d "$INSTALL_DIR" ]; then
+  $SUDO mkdir -p "$INSTALL_DIR"
+fi
+
+echo "Installing to $INSTALL_DIR"
+$SUDO install -m 755 "$BINARY" "$INSTALL_DIR/$BINARY"
+
+echo "$BINARY installed successfully to $INSTALL_DIR/$BINARY"
+# Use the full path: INSTALL_DIR may not be on PATH
+"$INSTALL_DIR/$BINARY" version || true
diff --git a/internal/client/main.go b/internal/client/main.go
index f240ef1..2eb8a92 100644
--- a/internal/client/main.go
+++ b/internal/client/main.go
@@ -4,6 +4,8 @@ import (
"bufio"
"bytes"
"context"
+ "crypto/tls"
+ "crypto/x509"
"errors"
"fmt"
"io"
@@ -26,6 +28,8 @@ type ConfigOptions struct {
TunnelHttpPort string
TunnelTcpPort string
TunnelHost string
+ CustomDns string
+ CustomCert string
}
type MmarClient struct {
@@ -58,9 +62,64 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
},
}
+	// Use custom DNS if set: resolve the local dev server's host through
+	// mc.CustomDns instead of the OS resolver.
+	if mc.CustomDns != "" {
+		r := &net.Resolver{
+			PreferGo: true,
+			// Honour the network the resolver asks for (it may request "tcp",
+			// e.g. for large responses) and its context deadline/cancellation.
+			Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+				var d net.Dialer
+				return d.DialContext(ctx, network, mc.CustomDns)
+			},
+		}
+		dialer := &net.Dialer{
+			Resolver: r,
+		}
+
+		tp := &http.Transport{
+			DialContext: dialer.DialContext,
+		}
+
+		fwdClient.Transport = tp
+	}
+
+	// Use custom TLS certificate if setup
+	// NOTE(review): the certificate is read and parsed on every request;
+	// consider loading it once when the client starts.
+	if mc.CustomCert != "" {
+		certData, certFileErr := os.ReadFile(mc.CustomCert)
+		if certFileErr != nil {
+			logger.Log(
+				constants.RED,
+				fmt.Sprintf(
+					"Could not read certificate from file: %v",
+					certFileErr,
+				))
+			os.Exit(1)
+		}
+
+		cert, certErr := x509.ParseCertificate(certData)
+		if certErr != nil {
+			logger.Log(constants.YELLOW, "Warning: Could not load custom certificate")
+		} else {
+			// Unless the custom-DNS branch above (or fwdClient's construction)
+			// installed an *http.Transport, fwdClient.Transport may be nil and an
+			// unchecked type assertion would panic. Fall back to a copy of the
+			// default transport in that case.
+			tp, ok := fwdClient.Transport.(*http.Transport)
+			if !ok || tp == nil {
+				tp = http.DefaultTransport.(*http.Transport).Clone()
+				fwdClient.Transport = tp
+			}
+			certPool := x509.NewCertPool()
+			certPool.AddCert(cert)
+			tp.TLSClientConfig = &tls.Config{RootCAs: certPool}
+		}
+	}
+
reqReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData))
- req, reqErr := http.ReadRequest(reqReader)
+ // Extract RequestId
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ _, err := io.ReadFull(reqReader, reqIdBuff)
+ if err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse RequestId for request: %v\n", err))
+ return
+ }
+
+ // Include RequestId in tunnel back message
+ msgData := []byte{}
+ msgData = append(msgData, reqIdBuff...)
+
+ req, reqErr := http.ReadRequest(reqReader)
if reqErr != nil {
if errors.Is(reqErr, io.EOF) {
logger.Log(constants.DEFAULT_COLOR, "Connection to mmar server closed or disconnected. Exiting...")
@@ -80,32 +139,36 @@ func (mc *MmarClient) handleRequestMessage(tunnelMsg protocol.TunnelMessage) {
resp, fwdErr := fwdClient.Do(req)
if fwdErr != nil {
if errors.Is(fwdErr, syscall.ECONNREFUSED) || errors.Is(fwdErr, io.ErrUnexpectedEOF) || errors.Is(fwdErr, io.EOF) {
- localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING}
+ localhostNotRunningMsg := protocol.TunnelMessage{MsgType: protocol.LOCALHOST_NOT_RUNNING, MsgData: msgData}
if err := mc.SendMessage(localhostNotRunningMsg); err != nil {
log.Fatal(err)
}
return
} else if errors.Is(fwdErr, context.DeadlineExceeded) {
- destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT}
+ destServerTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.DEST_REQUEST_TIMEDOUT, MsgData: msgData}
if err := mc.SendMessage(destServerTimedoutMsg); err != nil {
log.Fatal(err)
}
return
}
- log.Fatalf("Failed to forward: %v", fwdErr)
+ invalidRespFromDestMsg := protocol.TunnelMessage{MsgType: protocol.INVALID_RESP_FROM_DEST, MsgData: msgData}
+ if err := mc.SendMessage(invalidRespFromDestMsg); err != nil {
+ log.Fatal(err)
+ }
+ return
}
- logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true)
-
// Writing response to buffer to tunnel it back
var responseBuff bytes.Buffer
resp.Write(&responseBuff)
-
- respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
+ msgData = append(msgData, responseBuff.Bytes()...)
+ respMessage := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: msgData}
if err := mc.SendMessage(respMessage); err != nil {
log.Fatal(err)
}
+
+ logger.LogHTTP(req, resp.StatusCode, resp.ContentLength, false, true)
}
// Keep attempting to reconnect the existing tunnel until successful
@@ -118,7 +181,7 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) {
logger.Log(constants.DEFAULT_COLOR, "Attempting to reconnect...")
conn, err := net.DialTimeout(
"tcp",
- fmt.Sprintf("%s:%s", mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort),
+ net.JoinHostPort(mc.ConfigOptions.TunnelHost, mc.ConfigOptions.TunnelTcpPort),
constants.TUNNEL_CREATE_TIMEOUT*time.Second,
)
if err != nil {
@@ -126,6 +189,15 @@ func (mc *MmarClient) reconnectTunnel(ctx context.Context) {
continue
}
mc.Tunnel.Conn = conn
+ mc.Tunnel.Reader = bufio.NewReader(conn)
+
+ // Try to reclaim the same subdomain
+ reclaimTunnelMsg := protocol.TunnelMessage{MsgType: protocol.RECLAIM_TUNNEL, MsgData: []byte(mc.subdomain)}
+ if err := mc.SendMessage(reclaimTunnelMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
+ os.Exit(0)
+ }
+
break
}
}
@@ -136,13 +208,35 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
case <-ctx.Done(): // Client gracefully shutdown
return
default:
+ // Send heartbeat if nothing has been read for a while
+ receiveMessageTimeout := time.AfterFunc(
+ constants.HEARTBEAT_FROM_CLIENT_TIMEOUT*time.Second,
+ func() {
+ heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_CLIENT}
+ if err := mc.SendMessage(heartbeatMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to send heartbeat. Exiting...")
+ os.Exit(0)
+ }
+ // Set a read timeout, if no response to heartbeat is recieved within that period,
+ // attempt to reconnect to the server
+ readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second))
+ mc.Tunnel.Conn.SetReadDeadline(readDeadline)
+ },
+ )
+
tunnelMsg, err := mc.ReceiveMessage()
+ // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout
+ // as we do not need to send heartbeat or check connection health in this iteration
+ receiveMessageTimeout.Stop()
+ mc.Tunnel.Conn.SetReadDeadline(time.Time{})
+
if err != nil {
// If the context was cancelled just return
if errors.Is(ctx.Err(), context.Canceled) {
return
- } else if errors.Is(err, os.ErrDeadlineExceeded) {
- continue
+ } else if errors.Is(err, protocol.INVALID_MESSAGE_PROTOCOL_VERSION) {
+ logger.Log(constants.YELLOW, "The mmar message protocol has been updated, please update mmar.")
+ os.Exit(0)
}
logger.Log(constants.DEFAULT_COLOR, "Tunnel connection disconnected.")
@@ -154,21 +248,9 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
}
switch tunnelMsg.MsgType {
- case protocol.CLIENT_CONNECT:
+ case protocol.TUNNEL_CREATED, protocol.TUNNEL_RECLAIMED:
tunnelSubdomain := string(tunnelMsg.MsgData)
- // If there is an existing subdomain, that means we are reconnecting with an
- // existing mmar client, try to reclaim the same subdomain
- if mc.subdomain != "" {
- reconnectMsg := protocol.TunnelMessage{MsgType: protocol.CLIENT_RECLAIM_SUBDOMAIN, MsgData: []byte(tunnelSubdomain + ":" + mc.subdomain)}
- mc.subdomain = ""
- if err := mc.SendMessage(reconnectMsg); err != nil {
- logger.Log(constants.DEFAULT_COLOR, "Tunnel failed to reconnect. Exiting...")
- os.Exit(0)
- }
- continue
- } else {
- mc.subdomain = tunnelSubdomain
- }
+ mc.subdomain = tunnelSubdomain
logger.LogTunnelCreated(tunnelSubdomain, mc.TunnelHost, mc.TunnelHttpPort, mc.LocalPort)
case protocol.CLIENT_TUNNEL_LIMIT:
limit := logger.ColorLogStr(
@@ -184,6 +266,15 @@ func (mc *MmarClient) ProcessTunnelMessages(ctx context.Context) {
os.Exit(0)
case protocol.REQUEST:
go mc.handleRequestMessage(tunnelMsg)
+ case protocol.HEARTBEAT_ACK:
+ // Got a heartbeat ack, that means the connection is healthy,
+ // we do not need to perform any action
+ case protocol.HEARTBEAT_FROM_SERVER:
+ heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK}
+ if err := mc.SendMessage(heartbeatAckMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to send Heartbeat Ack. Exiting...")
+ os.Exit(0)
+ }
}
}
}
@@ -198,7 +289,7 @@ func Run(config ConfigOptions) {
conn, err := net.DialTimeout(
"tcp",
- fmt.Sprintf("%s:%s", config.TunnelHost, config.TunnelTcpPort),
+ net.JoinHostPort(config.TunnelHost, config.TunnelTcpPort),
constants.TUNNEL_CREATE_TIMEOUT*time.Second,
)
if err != nil {
@@ -215,7 +306,7 @@ func Run(config ConfigOptions) {
}
defer conn.Close()
mmarClient := MmarClient{
- protocol.Tunnel{Conn: conn},
+ protocol.Tunnel{Conn: conn, Reader: bufio.NewReader(conn)},
config,
"",
}
@@ -226,6 +317,12 @@ func Run(config ConfigOptions) {
// Process Tunnel Messages coming from mmar server
go mmarClient.ProcessTunnelMessages(ctx)
+ createTunnelMsg := protocol.TunnelMessage{MsgType: protocol.CREATE_TUNNEL}
+ if err := mmarClient.SendMessage(createTunnelMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, "Failed to create Tunnel. Exiting...")
+ os.Exit(0)
+ }
+
// Wait for an interrupt signal, if received, terminate gracefully
<-sigInt
diff --git a/internal/protocol/main.go b/internal/protocol/main.go
index c0b3241..1287734 100644
--- a/internal/protocol/main.go
+++ b/internal/protocol/main.go
@@ -19,12 +19,18 @@ import (
+// Tunnel message types, numbered from 1 via iota.
+// NOTE(review): TunnelMessage.MsgType carries these values between mmar client
+// and server, so inserting or reordering entries renumbers every later entry;
+// presumably TUNNEL_MESSAGE_PROTOCOL_VERSION must be bumped whenever that happens.
+// isValidTunnelMessageType accepts REQUEST..INVALID_RESP_FROM_DEST, so a type
+// appended after INVALID_RESP_FROM_DEST also requires updating that bound.
const (
REQUEST = uint8(iota + 1)
RESPONSE
- CLIENT_CONNECT
- CLIENT_RECLAIM_SUBDOMAIN
+	// Tunnel setup: the client sends CREATE_TUNNEL or RECLAIM_TUNNEL (with its
+	// previous subdomain); the server answers TUNNEL_CREATED or TUNNEL_RECLAIMED
+	// carrying the assigned subdomain.
+ CREATE_TUNNEL
+ RECLAIM_TUNNEL
+ TUNNEL_CREATED
+ TUNNEL_RECLAIMED
CLIENT_DISCONNECT
CLIENT_TUNNEL_LIMIT
LOCALHOST_NOT_RUNNING
DEST_REQUEST_TIMEDOUT
+	// Liveness checks: the mmar client answers HEARTBEAT_FROM_SERVER with
+	// HEARTBEAT_ACK and treats a HEARTBEAT_ACK as the reply to its own
+	// HEARTBEAT_FROM_CLIENT.
+ HEARTBEAT_FROM_CLIENT
+ HEARTBEAT_FROM_SERVER
+ HEARTBEAT_ACK
+	// Sent by the mmar client when forwarding to the local server fails for a
+	// reason other than connection refused/EOF or a timeout.
+ INVALID_RESP_FROM_DEST
)
var INVALID_MESSAGE_PROTOCOL_VERSION = errors.New("Invalid Message Protocol Version")
@@ -33,7 +39,7 @@ var INVALID_MESSAGE_TYPE = errors.New("Invalid Tunnel Message Type")
func isValidTunnelMessageType(mt uint8) (uint8, error) {
// Iterate through all the message type, from first to last, checking
// if the provided message type matches one of them
- for msgType := REQUEST; msgType <= DEST_REQUEST_TIMEDOUT; msgType++ {
+ for msgType := REQUEST; msgType <= INVALID_RESP_FROM_DEST; msgType++ {
if mt == msgType {
return msgType, nil
}
@@ -45,9 +51,10 @@ func isValidTunnelMessageType(mt uint8) (uint8, error) {
func TunnelErrState(errState uint8) string {
// TODO: Have nicer/more elaborative error messages/pages
errStates := map[uint8]string{
- CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT,
- LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT,
- DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT,
+ CLIENT_DISCONNECT: constants.CLIENT_DISCONNECT_ERR_TEXT,
+ LOCALHOST_NOT_RUNNING: constants.LOCALHOST_NOT_RUNNING_ERR_TEXT,
+ DEST_REQUEST_TIMEDOUT: constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT,
+ INVALID_RESP_FROM_DEST: constants.READ_RESP_BODY_ERR_TEXT,
}
fallbackErr := "An error occured while attempting to tunnel."
@@ -71,6 +78,7 @@ type Tunnel struct {
Id string
Conn net.Conn
CreatedOn time.Time
+ Reader *bufio.Reader
}
type TunnelInterface interface {
@@ -173,6 +181,10 @@ func (tm *TunnelMessage) deserializeMessage(reader *bufio.Reader) error {
return nil
}
+// ReservedSubdomain reports whether a subdomain has already been assigned to
+// this tunnel, i.e. whether its Id is non-empty.
+func (t *Tunnel) ReservedSubdomain() bool {
+	return len(t.Id) > 0
+}
+
func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
// Serialize tunnel message data
serializedMsg, serializeErr := tunnelMsg.serializeMessage()
@@ -184,11 +196,9 @@ func (t *Tunnel) SendMessage(tunnelMsg TunnelMessage) error {
}
+// ReceiveMessage reads and deserializes the next TunnelMessage from the tunnel.
+// It reads through t.Reader, which must be set; the connection setup code
+// creates it with bufio.NewReader(conn).
func (t *Tunnel) ReceiveMessage() (TunnelMessage, error) {
- msgReader := bufio.NewReader(t.Conn)
-
// Read and deserialize tunnel message data
tunnelMessage := TunnelMessage{}
- deserializeErr := tunnelMessage.deserializeMessage(msgReader)
+	// Reuse the long-lived per-connection reader: a fresh bufio.Reader per call
+	// could buffer bytes that belong to the following message and then drop them.
+ deserializeErr := tunnelMessage.deserializeMessage(t.Reader)
return tunnelMessage, deserializeErr
}
diff --git a/internal/server/main.go b/internal/server/main.go
index fee090f..594ba88 100644
--- a/internal/server/main.go
+++ b/internal/server/main.go
@@ -4,19 +4,18 @@ import (
"bufio"
"bytes"
"context"
+ "encoding/binary"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"log"
- "math/rand"
"net"
"net/http"
"os"
"os/signal"
"slices"
- "strings"
"sync"
"time"
@@ -44,7 +43,6 @@ type IncomingRequest struct {
responseWriter http.ResponseWriter
request *http.Request
cancel context.CancelCauseFunc
- serializedReq []byte
ctx context.Context
}
@@ -53,11 +51,14 @@ type OutgoingResponse struct {
body []byte
}
+// RequestId tags a request tunnelled to an mmar client. ServeHTTP prefixes the
+// REQUEST message data with it (REQUEST_ID_BUFF_SIZE bytes, little-endian) and
+// the client echoes those bytes at the start of its reply.
+type RequestId uint32
+
// Tunnel to Client
type ClientTunnel struct {
protocol.Tunnel
- incomingChannel chan IncomingRequest
- outgoingChannel chan protocol.TunnelMessage
+ incomingChannel chan IncomingRequest
+ outgoingChannel chan protocol.TunnelMessage
+	// inflightRequests maps RequestId -> IncomingRequest for requests sent to the
+	// client that are still awaiting a reply; ServeHTTP deletes its entry when
+	// the request's context is cancelled.
+	// NOTE(review): ServeHTTP no longer sends on incomingChannel in this change;
+	// confirm whether that channel is still needed.
+ inflightRequests *sync.Map
}
func (ct *ClientTunnel) drainChannels() {
@@ -119,6 +120,16 @@ func (ct *ClientTunnel) close(graceful bool) {
)
}
+// GenerateUniqueRequestID returns a random, non-zero RequestId that is not
+// already used by an inflight request on this client tunnel.
+//
+// Each candidate is checked against inflightRequests inside the loop body: a
+// for-loop init statement runs only once, so a lookup placed there would never
+// detect a collision after the first draw.
+//
+// NOTE(review): the ID is not reserved here; the caller stores it afterwards,
+// so two concurrent callers could in principle draw the same value.
+func (ct *ClientTunnel) GenerateUniqueRequestID() RequestId {
+	for {
+		reqId := RequestId(GenerateRandomUint32())
+		if reqId == 0 {
+			// 0 is deliberately never handed out.
+			continue
+		}
+		if _, exists := ct.inflightRequests.Load(reqId); !exists {
+			return reqId
+		}
+	}
+}
+
// Serves simple stats for mmar server behind Basic Authentication
func (ms *MmarServer) handleServerStats(w http.ResponseWriter, r *http.Request) {
// Check Basic Authentication
@@ -193,19 +204,33 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Create response channel to receive response for tunneled request
respChannel := make(chan OutgoingResponse)
- // Tunnel the request
- clientTunnel.incomingChannel <- IncomingRequest{
+ // Add request to client's inflight requests
+ reqId := clientTunnel.GenerateUniqueRequestID()
+ incomingReq := IncomingRequest{
responseChannel: respChannel,
responseWriter: w,
request: r,
cancel: cancel,
- serializedReq: serializedRequest,
ctx: ctx,
}
+ clientTunnel.inflightRequests.Store(reqId, incomingReq)
+
+ // Construct Request message data
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ binary.LittleEndian.PutUint32(reqIdBuff, uint32(reqId))
+ reqMsgData := append(reqIdBuff, serializedRequest...)
+
+ // Tunnel the request to mmar client
+ reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: reqMsgData}
+ if err := clientTunnel.SendMessage(reqMessage); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err))
+ cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR)
+ }
select {
case <-ctx.Done(): // Request is canceled or Tunnel is closed if context is canceled
handleCancel(context.Cause(ctx), w)
+ clientTunnel.inflightRequests.Delete(reqId)
return
case resp, _ := <-respChannel: // Await response for tunneled request
// Add header to close the connection
@@ -220,20 +245,15 @@ func (ms *MmarServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
-func (ms *MmarServer) GenerateUniqueId() string {
- reservedIDs := []string{"", "admin", "stats"}
+// GenerateUniqueSubdomain returns a random subdomain that is neither reserved
+// nor already assigned to a connected client. It reads ms.clients; the visible
+// caller (newClientTunnel) holds ms.mu while calling it.
+func (ms *MmarServer) GenerateUniqueSubdomain() string {
+	reservedSubdomains := []string{"", "admin", "stats"}
- generatedId := ""
- for _, exists := ms.clients[generatedId]; exists || slices.Contains(reservedIDs, generatedId); {
- var randSeed *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
- b := make([]byte, constants.ID_LENGTH)
- for i := range b {
- b[i] = constants.ID_CHARSET[randSeed.Intn(len(constants.ID_CHARSET))]
- }
- generatedId = string(b)
+	// Check every candidate inside the loop body: a for-loop init statement runs
+	// only once, so a map lookup placed there is never refreshed and collisions
+	// with existing clients would go unnoticed.
+	for {
+		candidate := GenerateRandomID()
+		if _, exists := ms.clients[candidate]; exists {
+			continue
+		}
+		if slices.Contains(reservedSubdomains, candidate) {
+			continue
+		}
+		return candidate
}
- return generatedId
}
func (ms *MmarServer) TunnelLimitedIP(ip string) bool {
@@ -247,31 +267,40 @@ func (ms *MmarServer) TunnelLimitedIP(ip string) bool {
return len(tunnels) >= constants.MAX_TUNNELS_PER_IP
}
-func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
+func (ms *MmarServer) newClientTunnel(tunnel protocol.Tunnel, subdomain string) (*ClientTunnel, error) {
// Acquire lock to create new client tunnel data
ms.mu.Lock()
- // Generate unique ID for client
- uniqueId := ms.GenerateUniqueId()
- tunnel := protocol.Tunnel{
- Id: uniqueId,
- Conn: conn,
- CreatedOn: time.Now(),
+ var uniqueSubdomain string
+ var msgType uint8
+ if subdomain != "" {
+ uniqueSubdomain = subdomain
+ msgType = protocol.TUNNEL_RECLAIMED
+ } else {
+ // Generate unique subdomain for client if not passed in
+ uniqueSubdomain = ms.GenerateUniqueSubdomain()
+ msgType = protocol.TUNNEL_CREATED
}
+ tunnel.Id = uniqueSubdomain
+
// Create channels to tunnel requests to and recieve responses from
incomingChannel := make(chan IncomingRequest)
outgoingChannel := make(chan protocol.TunnelMessage)
+ // Initialize inflight requests map for client tunnel
+ var inflightRequests sync.Map
+
// Create client tunnel
clientTunnel := ClientTunnel{
tunnel,
incomingChannel,
outgoingChannel,
+ &inflightRequests,
}
// Check if IP reached max tunnel limit
- clientIP := utils.ExtractIP(conn.RemoteAddr().String())
+ clientIP := utils.ExtractIP(tunnel.Conn.RemoteAddr().String())
limitedIP := ms.TunnelLimitedIP(clientIP)
// If so, send limit message to client and close client tunnel
if limitedIP {
@@ -286,18 +315,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
}
// Add client tunnel to clients
- ms.clients[uniqueId] = clientTunnel
+ ms.clients[uniqueSubdomain] = clientTunnel
// Associate tunnel with client IP
- ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueId)
+ ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], uniqueSubdomain)
// Release lock once created
ms.mu.Unlock()
- // Send unique ID to client
- connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(uniqueId)}
+ // Send unique subdomain to client
+ connMessage := protocol.TunnelMessage{MsgType: msgType, MsgData: []byte(uniqueSubdomain)}
if err := clientTunnel.SendMessage(connMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err))
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique subdomain msg to client: %v", err))
return nil, err
}
@@ -306,32 +335,18 @@ func (ms *MmarServer) newClientTunnel(conn net.Conn) (*ClientTunnel, error) {
+// handleTcpConnection wraps an accepted mmar client connection in a protocol.Tunnel and
+// starts processTunnelMessages for it on a new goroutine. No ClientTunnel is created
+// here; that happens once the client sends CREATE_TUNNEL or RECLAIM_TUNNEL.
func (ms *MmarServer) handleTcpConnection(conn net.Conn) {
- clientTunnel, err := ms.newClientTunnel(conn)
-
- if err != nil {
- if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
- // Close the connection when client max tunnels limit reached
- conn.Close()
- return
- }
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err))
- return
+ tunnel := protocol.Tunnel{
+ Conn: conn,
+ CreatedOn: time.Now(),
+ Reader: bufio.NewReader(conn),
}
- logger.Log(
- constants.DEFAULT_COLOR,
- fmt.Sprintf(
- "[%s] Tunnel created: %s",
- clientTunnel.Tunnel.Id,
- conn.RemoteAddr().String(),
- ),
- )
-
// Process Tunnel Messages coming from mmar client
- go ms.processTunnelMessages(clientTunnel)
+ go ms.processTunnelMessages(tunnel)
+}
- // Start goroutine to process tunneled requests
- go ms.processTunneledRequestsForClient(clientTunnel)
+// closeTunnel closes the raw TCP connection backing t. It does not remove anything
+// from ms.clients, and the error returned by Close is ignored.
+func (ms *MmarServer) closeTunnel(t *protocol.Tunnel) {
+ t.Conn.Close()
}
func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) {
@@ -351,168 +366,216 @@ func (ms *MmarServer) closeClientTunnel(ct *ClientTunnel) {
ct.close(true)
}
-func (ms *MmarServer) processTunneledRequestsForClient(ct *ClientTunnel) {
- for {
- // Read requests coming in tunnel channel
- incomingReq, ok := <-ct.incomingChannel
- if !ok {
- // Channel closed, client disconencted, shutdown goroutine
- return
- }
+// closeClientTunnelOrConn releases a client connection. When no ClientTunnel exists yet
+// (ct is nil) or ct.ReservedSubdomain() reports false, only the raw TCP connection is
+// closed; otherwise the whole client tunnel is torn down via closeClientTunnel.
+func (ms *MmarServer) closeClientTunnelOrConn(ct *ClientTunnel, t protocol.Tunnel) {
- // Forward the request to mmar client
- reqMessage := protocol.TunnelMessage{MsgType: protocol.REQUEST, MsgData: incomingReq.serializedReq}
- if err := ct.SendMessage(reqMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send Request msg to client: %v", err))
- incomingReq.cancel(FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR)
- continue
- }
+ // ct is nil when the connection is closed before CREATE_TUNNEL/RECLAIM_TUNNEL succeeded
+ // (e.g. heartbeat failure, network error, disconnect or a reclaim of a taken subdomain);
+ // check it first so we never call a method on a nil *ClientTunnel. In that case, or when
+ // the client has not reserved a subdomain, just close the tcp connection
+ if ct == nil || !ct.ReservedSubdomain() {
+ ms.closeTunnel(&t)
+ return
+ }
- // Wait for response for this request to come back from outgoing channel
- respTunnelMsg, ok := <-ct.outgoingChannel
- if !ok {
- // Channel closed, client disconencted, shutdown goroutine
- return
- }
+ ms.closeClientTunnel(ct)
+}
- // Read response for forwarded request
- respReader := bufio.NewReader(bytes.NewReader(respTunnelMsg.MsgData))
- resp, respErr := http.ReadResponse(respReader, incomingReq.request)
+// handleResponseMessages matches a RESPONSE payload (a REQUEST_ID_BUFF_SIZE-byte request
+// id, decoded little-endian, followed by a serialized HTTP response) to its inflight
+// request, removes that request from ct.inflightRequests and hands the parsed status,
+// headers and body to the request's handler unless the request was already canceled.
+func (ms *MmarServer) handleResponseMessages(ct *ClientTunnel, tunnelMsg protocol.TunnelMessage) {
+ // ct may be nil if a response arrives on a connection that has not created a tunnel
+ // yet; there are no inflight requests to match, and dereferencing it would panic
+ if ct == nil {
+ logger.Log(constants.DEFAULT_COLOR, "Received response message before a tunnel was created, ignoring")
+ return
+ }
+ respReader := bufio.NewReader(bytes.NewReader(tunnelMsg.MsgData))
- if respErr != nil {
- if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) {
- incomingReq.cancel(CLIENT_DISCONNECTED_ERR)
- ms.closeClientTunnel(ct)
- return
- }
- failedReq := fmt.Sprintf("%s - %s%s", incomingReq.request.Method, html.EscapeString(incomingReq.request.URL.Path), incomingReq.request.URL.RawQuery)
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq))
- incomingReq.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR)
- continue
- }
+ // Extract RequestId
+ reqIdBuff := make([]byte, constants.REQUEST_ID_BUFF_SIZE)
+ _, err := io.ReadFull(respReader, reqIdBuff)
+ if err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] - Failed to parse RequestId for response: %v\n", ct.Tunnel.Id, err))
+ return
+ }
- respBody, respBodyErr := io.ReadAll(resp.Body)
- if respBodyErr != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr))
- incomingReq.cancel(READ_RESP_BODY_ERR)
- continue
- }
+ // Get Inflight Request and remove it from inflight requests
+ reqId := RequestId(binary.LittleEndian.Uint32(reqIdBuff))
+ inflight, loaded := ct.inflightRequests.LoadAndDelete(reqId)
+ if !loaded {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to identify inflight request: %v", ct.Tunnel.Id, reqId))
+ return
+ }
- // Set headers for response
- for hKey, hVal := range resp.Header {
- incomingReq.responseWriter.Header().Set(hKey, hVal[0])
- // Add remaining values for header if more than than one exists
- for i := 1; i < len(hVal); i++ {
- incomingReq.responseWriter.Header().Add(hKey, hVal[i])
- }
+ inflightRequest, ok := inflight.(IncomingRequest)
+ if !ok {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("[%s] Failed to parse inflight request: %v", ct.Tunnel.Id, reqId))
+ return
+ }
+
+ // Read response for forwarded request
+ resp, respErr := http.ReadResponse(respReader, inflightRequest.request)
+
+ if respErr != nil {
+ if errors.Is(respErr, io.ErrUnexpectedEOF) || errors.Is(respErr, net.ErrClosed) {
+ inflightRequest.cancel(CLIENT_DISCONNECTED_ERR)
+ ms.closeClientTunnel(ct)
+ return
}
+ failedReq := fmt.Sprintf("%s - %s%s", inflightRequest.request.Method, html.EscapeString(inflightRequest.request.URL.Path), inflightRequest.request.URL.RawQuery)
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to return response: %v\n\n for req: %v", respErr, failedReq))
+ inflightRequest.cancel(FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR)
+ return
+ }
- // Close response body
- resp.Body.Close()
+ // Release the body on every path, including when reading it fails below
+ defer resp.Body.Close()
+
+ respBody, respBodyErr := io.ReadAll(resp.Body)
+ if respBodyErr != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to parse response body: %v\n\n", respBodyErr))
+ inflightRequest.cancel(READ_RESP_BODY_ERR)
+ return
+ }
- select {
- case <-incomingReq.ctx.Done():
- // Request is canceled, on to the next request
- continue
- case incomingReq.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}:
- // Send response data back
+ // Set headers for response
+ for hKey, hVal := range resp.Header {
+ inflightRequest.responseWriter.Header().Set(hKey, hVal[0])
+ // Add remaining values for header if more than one exists
+ for i := 1; i < len(hVal); i++ {
+ inflightRequest.responseWriter.Header().Add(hKey, hVal[i])
}
}
+
+ select {
+ case <-inflightRequest.ctx.Done():
+ // Request is canceled, do nothing
+ return
+ case inflightRequest.responseChannel <- OutgoingResponse{statusCode: resp.StatusCode, body: respBody}:
+ // Send response data back
+ }
}
-func (ms *MmarServer) processTunnelMessages(ct *ClientTunnel) {
+// processTunnelMessages reads and dispatches messages from a single mmar client
+// connection. ct stays nil until the client successfully sends CREATE_TUNNEL or
+// RECLAIM_TUNNEL; messages that need a ClientTunnel are ignored until then.
+func (ms *MmarServer) processTunnelMessages(t protocol.Tunnel) {
+ var ct *ClientTunnel
for {
- tunnelMsg, err := ct.ReceiveMessage()
+ // The heartbeat callback runs on the timer's own goroutine, so give it a copy of ct
+ // rather than letting it read the variable this loop reassigns (time.Timer.Stop does
+ // not wait for an already-running callback to finish)
+ heartbeatCt := ct
+ // Send heartbeat if nothing has been read for a while
+ receiveMessageTimeout := time.AfterFunc(
+ constants.HEARTBEAT_FROM_SERVER_TIMEOUT*time.Second,
+ func() {
+ heartbeatMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_FROM_SERVER}
+ if err := t.SendMessage(heartbeatMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send heartbeat: %v", err))
+ ms.closeClientTunnelOrConn(heartbeatCt, t)
+ return
+ }
+ // Set a read timeout, if no response to heartbeat is received within that period,
+ // that means the client has disconnected
+ readDeadline := time.Now().Add((constants.READ_DEADLINE * time.Second))
+ t.Conn.SetReadDeadline(readDeadline)
+ },
+ )
+
+ tunnelMsg, err := t.ReceiveMessage()
+ // If a message is received, stop the receiveMessageTimeout and remove the ReadTimeout
+ // as we do not need to send heartbeat or check connection health in this iteration
+ receiveMessageTimeout.Stop()
+ t.Conn.SetReadDeadline(time.Time{})
+
if err != nil {
logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Receive Message from client tunnel errored: %v", err))
if utils.NetworkError(err) {
// If error with connection, stop processing messages
- ms.closeClientTunnel(ct)
+ ms.closeClientTunnelOrConn(ct, t)
return
}
continue
}
+ // Responses are matched against ct's inflight requests, so they cannot be handled
+ // before a tunnel exists; ignore them instead of dereferencing a nil ct
+ if ct == nil {
+ switch tunnelMsg.MsgType {
+ case protocol.RESPONSE, protocol.LOCALHOST_NOT_RUNNING, protocol.DEST_REQUEST_TIMEDOUT, protocol.INVALID_RESP_FROM_DEST:
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Ignoring message type %v received before a tunnel was created", tunnelMsg.MsgType))
+ continue
+ }
+ }
switch tunnelMsg.MsgType {
- case protocol.RESPONSE:
- ct.outgoingChannel <- tunnelMsg
- case protocol.LOCALHOST_NOT_RUNNING:
- // Create a response for Tunnel connected but localhost not running
- errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING)
- resp := http.Response{
- Status: "200 OK",
- StatusCode: http.StatusOK,
- Body: io.NopCloser(bytes.NewBufferString(errState)),
- }
+ case protocol.CREATE_TUNNEL:
+ // mmar client requesting new tunnel
+ ct, err = ms.newClientTunnel(t, "")
- // Writing response to buffer to tunnel it back
- var responseBuff bytes.Buffer
- resp.Write(&responseBuff)
- notRunningMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
- ct.outgoingChannel <- notRunningMsg
- case protocol.DEST_REQUEST_TIMEDOUT:
- // Create a response for Tunnel connected but localhost took too long to respond
- errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT)
- resp := http.Response{
- Status: "200 OK",
- StatusCode: http.StatusOK,
- Body: io.NopCloser(bytes.NewBufferString(errState)),
+ if err != nil {
+ if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
+ // Close the connection when client max tunnels limit reached
+ t.Conn.Close()
+ return
+ }
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to create ClientTunnel: %v", err))
+ // Nothing reads from this connection after returning, so close it instead of leaking it
+ t.Conn.Close()
+ return
}
- // Writing response to buffer to tunnel it back
- var responseBuff bytes.Buffer
- resp.Write(&responseBuff)
- destTimedoutMsg := protocol.TunnelMessage{MsgType: protocol.RESPONSE, MsgData: responseBuff.Bytes()}
- ct.outgoingChannel <- destTimedoutMsg
- case protocol.CLIENT_DISCONNECT:
- ms.closeClientTunnel(ct)
- return
- case protocol.CLIENT_RECLAIM_SUBDOMAIN:
- newAndExistingIDs := strings.Split(string(tunnelMsg.MsgData), ":")
- newId := newAndExistingIDs[0]
- existingId := newAndExistingIDs[1]
+ logger.Log(
+ constants.DEFAULT_COLOR,
+ fmt.Sprintf(
+ "[%s] Tunnel created: %s",
+ ct.Tunnel.Id,
+ t.Conn.RemoteAddr().String(),
+ ),
+ )
+ case protocol.RECLAIM_TUNNEL:
+ // mmar client reclaiming a previously created tunnel
+ existingId := string(tunnelMsg.MsgData)
// Check if the subdomain has already been taken
- _, ok := ms.clients[existingId]
+ // ms.clients is written by other connections' goroutines under ms.mu, so read it under
+ // the same lock. NOTE(review): the lock is released before newClientTunnel, so two
+ // clients reclaiming the same subdomain concurrently can still both pass this check
+ ms.mu.Lock()
+ _, ok := ms.clients[existingId]
+ ms.mu.Unlock()
if ok {
// if so, close the tunnel, so the user can create a new one
- ms.closeClientTunnel(ct)
+ ms.closeClientTunnelOrConn(ct, t)
return
}
- ct.Tunnel.Id = existingId
-
- // Add existing client tunnel to clients
- ms.clients[existingId] = *ct
-
- // Remove newId tunnel from clients
- delete(ms.clients, newId)
-
- // Update the tunnels for the IP
- clientIP := utils.ExtractIP(ct.Conn.RemoteAddr().String())
- newIdIndex := slices.Index(ms.tunnelsPerIP[clientIP], newId)
- if newIdIndex == -1 {
- ms.tunnelsPerIP[clientIP] = append(ms.tunnelsPerIP[clientIP], existingId)
- } else {
- ms.tunnelsPerIP[clientIP][newIdIndex] = existingId
- }
-
- connMessage := protocol.TunnelMessage{MsgType: protocol.CLIENT_CONNECT, MsgData: []byte(existingId)}
- if err := ct.SendMessage(connMessage); err != nil {
- logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send unique ID msg to client: %v", err))
- ms.closeClientTunnel(ct)
+ ct, err = ms.newClientTunnel(t, existingId)
+ if err != nil {
+ if errors.Is(err, CLIENT_MAX_TUNNELS_REACHED) {
+ // Close the connection when client max tunnels limit reached
+ t.Conn.Close()
+ return
+ }
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to reclaim ClientTunnel: %v", err))
+ // Nothing reads from this connection after returning, so close it instead of leaking it
+ t.Conn.Close()
return
}
logger.Log(
constants.DEFAULT_COLOR,
fmt.Sprintf(
- "[%s] Tunnel reclaimed: %s -> %s",
- newId,
- ct.Conn.RemoteAddr().String(),
+ "[%s] Tunnel reclaimed: %s",
existingId,
+ ct.Conn.RemoteAddr().String(),
),
)
+ case protocol.RESPONSE:
+ go ms.handleResponseMessages(ct, tunnelMsg)
+ case protocol.LOCALHOST_NOT_RUNNING:
+ // Create a response for Tunnel connected but localhost not running
+ errState := protocol.TunnelErrState(protocol.LOCALHOST_NOT_RUNNING)
+ responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState)
+ notRunningMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, notRunningMsg)
+ case protocol.DEST_REQUEST_TIMEDOUT:
+ // Create a response for Tunnel connected but localhost took too long to respond
+ errState := protocol.TunnelErrState(protocol.DEST_REQUEST_TIMEDOUT)
+ responseBuff := createSerializedServerResp("200 OK", http.StatusOK, errState)
+ destTimedoutMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, destTimedoutMsg)
+ case protocol.CLIENT_DISCONNECT:
+ ms.closeClientTunnelOrConn(ct, t)
+ return
+ case protocol.HEARTBEAT_FROM_CLIENT:
+ heartbeatAckMsg := protocol.TunnelMessage{MsgType: protocol.HEARTBEAT_ACK}
+ if err := t.SendMessage(heartbeatAckMsg); err != nil {
+ logger.Log(constants.DEFAULT_COLOR, fmt.Sprintf("Failed to send heartbeat ack to client: %v", err))
+ ms.closeClientTunnelOrConn(ct, t)
+ return
+ }
+ case protocol.HEARTBEAT_ACK:
+ // Got a heartbeat ack, that means the connection is healthy,
+ // we do not need to perform any action
+ case protocol.INVALID_RESP_FROM_DEST:
+ // Create a response for receiving invalid response from destination server
+ errState := protocol.TunnelErrState(protocol.INVALID_RESP_FROM_DEST)
+ responseBuff := createSerializedServerResp("500 Internal Server Error", http.StatusInternalServerError, errState)
+ invalidRespFromDestMsg := protocol.TunnelMessage{
+ MsgType: protocol.RESPONSE,
+ MsgData: append(tunnelMsg.MsgData, responseBuff.Bytes()...),
+ }
+ go ms.handleResponseMessages(ct, invalidRespFromDestMsg)
}
}
}
diff --git a/internal/server/utils.go b/internal/server/utils.go
index 2fd144b..41d9b68 100644
--- a/internal/server/utils.go
+++ b/internal/server/utils.go
@@ -3,9 +3,12 @@ package server
import (
"bytes"
"context"
+ cryptoRand "crypto/rand"
+ "encoding/binary"
"errors"
"fmt"
"io"
+ mathRand "math/rand"
"net/http"
"strconv"
"time"
@@ -21,7 +24,7 @@ var MAX_REQ_BODY_SIZE_ERR error = errors.New(constants.MAX_REQ_BODY_SIZE_ERR_TEX
var FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR_TEXT)
var FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR error = errors.New(constants.FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR_TEXT)
-func responseWith(respText string, w http.ResponseWriter, statusCode int) {
+func respondWith(respText string, w http.ResponseWriter, statusCode int) {
w.Header().Set("Content-Length", strconv.Itoa(len(respText)))
w.Header().Set("Connection", "close")
w.WriteHeader(statusCode)
@@ -34,15 +37,15 @@ func handleCancel(cause error, w http.ResponseWriter) {
// Cancelled, do nothing
return
case READ_BODY_CHUNK_TIMEOUT_ERR:
- responseWith(cause.Error(), w, http.StatusRequestTimeout)
+ respondWith(cause.Error(), w, http.StatusRequestTimeout)
case READ_BODY_CHUNK_ERR, CLIENT_DISCONNECTED_ERR:
- responseWith(cause.Error(), w, http.StatusBadRequest)
+ respondWith(cause.Error(), w, http.StatusBadRequest)
case READ_RESP_BODY_ERR:
- responseWith(cause.Error(), w, http.StatusInternalServerError)
+ respondWith(cause.Error(), w, http.StatusInternalServerError)
case MAX_REQ_BODY_SIZE_ERR:
- responseWith(cause.Error(), w, http.StatusRequestEntityTooLarge)
+ respondWith(cause.Error(), w, http.StatusRequestEntityTooLarge)
case FAILED_TO_FORWARD_TO_MMAR_CLIENT_ERR, FAILED_TO_READ_RESP_FROM_MMAR_CLIENT_ERR:
- responseWith(cause.Error(), w, http.StatusServiceUnavailable)
+ respondWith(cause.Error(), w, http.StatusServiceUnavailable)
}
}
@@ -119,3 +122,35 @@ func serializeRequest(ctx context.Context, r *http.Request, cancel context.Cance
// Send serialized request through channel
serializedRequestChannel <- requestBuff.Bytes()
}
+
+// Create an HTTP/1.1 response sent from mmar server to the end-user client and
+// serialize it into a buffer so it can go through the regular response-handling path
+func createSerializedServerResp(status string, statusCode int, body string) bytes.Buffer {
+ resp := http.Response{
+ Status: status,
+ StatusCode: statusCode,
+ // Without an explicit version Response.Write emits an "HTTP/0.0" status line
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ // A known length gives the serialized response explicit Content-Length framing
+ ContentLength: int64(len(body)),
+ Body: io.NopCloser(bytes.NewBufferString(body)),
+ }
+
+ // Writing response to buffer to tunnel it back. Writing an in-memory body into a
+ // bytes.Buffer cannot fail, so the returned error is intentionally ignored.
+ var responseBuff bytes.Buffer
+ _ = resp.Write(&responseBuff)
+
+ return responseBuff
+}
+
+// Generate a random ID from ID_CHARSET of length ID_LENGTH.
+// Uses the automatically seeded top-level math/rand functions (Go 1.20+): building a new
+// source seeded with time.Now().UnixNano() on every call allocates each time and can hand
+// out identical IDs to calls made within the same clock tick.
+// NOTE(review): math/rand output is not meant to be unguessable; switch to crypto/rand if
+// subdomains must be hard to predict.
+func GenerateRandomID() string {
+ b := make([]byte, constants.ID_LENGTH)
+ for i := range b {
+ b[i] = constants.ID_CHARSET[mathRand.Intn(len(constants.ID_CHARSET))]
+ }
+ return string(b)
+}
+
+// Generate a random 32-bit unsigned integer using crypto/rand.
+// If reading from crypto/rand fails, fall back to math/rand instead of returning the
+// zero value, which would make every ID generated during the failure collide.
+func GenerateRandomUint32() uint32 {
+ var randomUint32 uint32
+ if err := binary.Read(cryptoRand.Reader, binary.BigEndian, &randomUint32); err != nil {
+ return mathRand.Uint32()
+ }
+ return randomUint32
+}
diff --git a/internal/utils/main.go b/internal/utils/main.go
index a35d885..b9aa575 100644
--- a/internal/utils/main.go
+++ b/internal/utils/main.go
@@ -112,5 +112,14 @@ func NetworkError(err error) bool {
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrUnexpectedEOF) ||
errors.Is(err, net.ErrClosed) ||
- errors.Is(err, syscall.ECONNRESET)
+ errors.Is(err, syscall.ECONNRESET) ||
+ errors.Is(err, os.ErrDeadlineExceeded)
+}
+
+// EnvVarOrDefault returns the value of the environment variable envVar when it is set
+// (even if set to the empty string), and defaultVal otherwise.
+func EnvVarOrDefault(envVar string, defaultVal string) string {
+ if envValue, ok := os.LookupEnv(envVar); ok {
+ return envValue
+ }
+ return defaultVal
}
diff --git a/simulations/devserver/main.go b/simulations/devserver/main.go
index d7080ca..973d42a 100644
--- a/simulations/devserver/main.go
+++ b/simulations/devserver/main.go
@@ -25,11 +25,19 @@ type DevServer struct {
*httptest.Server
}
-func NewDevServer() *DevServer {
+// NewDevServer starts an httptest server serving setupMux() over proto, which must be
+// "http" or "https"; any other value panics immediately instead of returning a DevServer
+// that wraps a nil *httptest.Server. addr is currently unused here: the httptest server
+// chooses its own local listen address.
+func NewDevServer(proto string, addr string) *DevServer {
mux := setupMux()
+ var httpServer *httptest.Server
+ switch proto {
+ case "https":
+ httpServer = httptest.NewTLSServer(mux)
+ case "http":
+ httpServer = httptest.NewServer(mux)
+ default:
+ panic("devserver: unsupported protocol " + proto)
+ }
+
return &DevServer{
- httptest.NewServer(mux),
+ httpServer,
}
}
@@ -170,19 +178,27 @@ func handleRedirect(w http.ResponseWriter, r *http.Request) {
// Request handler that returns an invalid HTTP response
func handleBadResp(w http.ResponseWriter, r *http.Request) {
- // Return a response with Content-Length headers that do not match the actual data
- respBody, err := json.Marshal(map[string]interface{}{
- "data": "some data",
- })
+ // Get the underlying connection object
+ // Assert that w supports Hijacking
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
+ return
+ }
+ // Hijack the connection
+ // Taking over the raw connection bypasses net/http's response writing, so the bytes
+ // written below reach the caller without any status line or headers
+ conn, buf, err := hijacker.Hijack()
if err != nil {
- log.Fatalf("Failed to marshal response for GET: %v", err)
+ http.Error(w, "Hijacking failed", http.StatusInternalServerError)
+ return
}
+ defer conn.Close()
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Length", "123") // Content length much larger than actual content
- w.WriteHeader(http.StatusOK)
- w.Write(respBody)
+ // Send back an invalid HTTP response
+ buf.WriteString("some random string\r\n" +
+ "\r\n" +
+ "that is not a valid http resp")
+ // Flush pushes the buffered bytes to conn before the deferred Close runs
+ buf.Flush()
}
// Request handler that takes a long time before returning response
diff --git a/simulations/simulation_test.go b/simulations/simulation_test.go
index 1fc2173..174efec 100644
--- a/simulations/simulation_test.go
+++ b/simulations/simulation_test.go
@@ -43,7 +43,15 @@ func StartMmarServer(ctx context.Context) {
}
}
-func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort string) {
+func StartMmarClient(
+ ctx context.Context,
+ urlCh chan string,
+ localDevServerPort string,
+ localDevServerHost string,
+ localDevServerProto string,
+ customDns string,
+ customCert string,
+) {
cmd := exec.CommandContext(
ctx,
"./mmar",
@@ -54,6 +62,26 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
localDevServerPort,
)
+ if localDevServerHost != "" {
+ cmd.Args = append(cmd.Args, "--local-host", localDevServerHost)
+ }
+
+ if localDevServerProto != "" {
+ cmd.Args = append(cmd.Args, "--local-proto", localDevServerProto)
+ }
+
+ if customDns != "" {
+ cmd.Args = append(cmd.Args, "--custom-dns", customDns)
+ }
+
+ if customCert != "" {
+ cmd.Args = append(cmd.Args, "--custom-cert", customCert)
+ }
+
+ cmd.Args = append(cmd.Args, "")
+
+ cmd.Stdout = os.Stdout
+
// Pipe Stderr To capture logs for extracting the tunnel url
pipe, _ := cmd.StderrPipe()
@@ -77,7 +105,6 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
tunnelUrl := extractTunnelURL(line)
if tunnelUrl != "" {
urlCh <- tunnelUrl
- break
}
line, readErr = stdoutReader.ReadString('\n')
}
@@ -91,9 +118,9 @@ func StartMmarClient(ctx context.Context, urlCh chan string, localDevServerPort
}
}
-func StartLocalDevServer() *devserver.DevServer {
- ds := devserver.NewDevServer()
- log.Printf("Started local dev server on: http://localhost:%v", ds.Port())
+// StartLocalDevServer creates a dev server for proto ("http" or "https") and logs where it
+// is reachable; addr is passed through to NewDevServer and used in the log line.
+func StartLocalDevServer(proto string, addr string) *devserver.DevServer {
+ ds := devserver.NewDevServer(proto, addr)
+ log.Printf("Started local dev server on: %v://%v:%v", proto, addr, ds.Port())
return ds
}
@@ -115,12 +142,13 @@ func verifyGetRequestSuccess(t *testing.T, client *http.Client, tunnelUrl string
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-get-request-success"},
}
@@ -159,12 +187,13 @@ func verifyGetRequestFail(t *testing.T, client *http.Client, tunnelUrl string, w
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-get-request-fail"},
}
@@ -210,12 +239,13 @@ func verifyPostRequestSuccess(t *testing.T, client *http.Client, tunnelUrl strin
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-post-request-success"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -265,12 +295,13 @@ func verifyPostRequestFail(t *testing.T, client *http.Client, tunnelUrl string,
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-post-request-fail"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -319,12 +350,13 @@ func verifyRedirectsHandled(t *testing.T, client *http.Client, tunnelUrl string,
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-redirect-request"},
"Referer": {tunnelUrl + "/redirect"}, // Include referer header since it redirects
}
@@ -378,6 +410,7 @@ func verifyInvalidMethodRequestHandled(t *testing.T, client *http.Client, tunnel
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-invalid-method-request"},
}
@@ -450,7 +483,6 @@ func verifyInvalidHttpVersionRequestHandled(t *testing.T, tunnelUrl string, wg *
if respErr != nil {
t.Errorf("%v: Failed to get response %v", "verifyInvalidHttpVersionRequestHandled", respErr)
}
-
if resp.StatusCode != http.StatusBadRequest {
t.Errorf(
"%v: resp.StatusCode = %v; want %v",
@@ -546,7 +578,6 @@ func verifyContentLengthWithNoBodyRequestHandled(t *testing.T, tunnelUrl string,
if respErr != nil {
t.Errorf("%v: Failed to get response %v", "verifyContentLengthWithNoBodyRequestHandled", respErr)
}
-
expectedBody := constants.READ_BODY_CHUNK_TIMEOUT_ERR_TEXT
expectedResp := expectedResponse{
@@ -580,12 +611,13 @@ func verifyRequestWithLargeBody(t *testing.T, client *http.Client, tunnelUrl str
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedReqHeaders := map[string][]string{
"User-Agent": {"Go-http-client/1.1"}, // Default header in golang client
"Accept-Encoding": {"gzip"}, // Default header in golang client
+ "Connection": {"close"},
"Simulation-Test": {"verify-large-post-request-success"},
"Content-Length": {strconv.Itoa(len(serializedReqBody))},
}
@@ -634,7 +666,11 @@ func verifyRequestWithVeryLargeBody(t *testing.T, client *http.Client, tunnelUrl
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ // Check if connection was closed in the middle of writing, that's also valid behavior
+ if !strings.Contains(respErr.Error(), "write: connection reset by peer") {
+ t.Errorf("Failed to get response: %v", respErr)
+ }
+ return
}
expectedBody := constants.MAX_REQ_BODY_SIZE_ERR_TEXT
@@ -661,7 +697,7 @@ func verifyDevServerReturningInvalidRespHandled(t *testing.T, client *http.Clien
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.READ_RESP_BODY_ERR_TEXT
@@ -688,7 +724,7 @@ func verifyDevServerLongRunningReqHandledGradefully(t *testing.T, client *http.C
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.DEST_REQUEST_TIMEDOUT_ERR_TEXT
@@ -715,7 +751,7 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu
resp, respErr := client.Do(req)
if respErr != nil {
- log.Printf("Failed to get response: %v", respErr)
+ t.Errorf("Failed to get response: %v", respErr)
}
expectedBody := constants.LOCALHOST_NOT_RUNNING_ERR_TEXT
@@ -735,19 +771,45 @@ func verifyDevServerCrashHandledGracefully(t *testing.T, client *http.Client, tu
func TestSimulation(t *testing.T) {
simulationCtx, simulationCancel := context.WithCancel(context.Background())
- localDevServer := StartLocalDevServer()
+ // Start a local dev server with http
+ localDevServer := StartLocalDevServer("http", "localhost")
defer localDevServer.Close()
+ // Start a local dev server with https
+ localDevTLSServer := StartLocalDevServer("https", "example.com")
+ defer localDevTLSServer.Close()
+
+ // Write cert to file so we are able to pass it into mmar client
+ certErr := os.WriteFile("./temp-cert", localDevTLSServer.Certificate().Raw, 0644) // 0644 is file permissions
+ if certErr != nil {
+ log.Fatal(certErr)
+ }
+
go dnsserver.StartDnsServer()
go StartMmarServer(simulationCtx)
wait := time.NewTimer(2 * time.Second)
<-wait.C
- clientUrlCh := make(chan string)
- go StartMmarClient(simulationCtx, clientUrlCh, localDevServer.Port())
- // Wait for tunnel url
- tunnelUrl := <-clientUrlCh
+ // Start a basic mmar client
+ basicClientUrlCh := make(chan string)
+ go StartMmarClient(simulationCtx, basicClientUrlCh, localDevServer.Port(), "", "", "", "")
+
+ // Start another basic mmar client
+ basicClientUrlCh2 := make(chan string)
+ go StartMmarClient(simulationCtx, basicClientUrlCh2, localDevServer.Port(), "", "", "", "")
+
+ // Wait for all tunnel urls
+ mmarClientsCount := 2
+ tunnelUrls := []string{}
+ for range mmarClientsCount {
+ select {
+ case tunnelUrl := <-basicClientUrlCh:
+ tunnelUrls = append(tunnelUrls, tunnelUrl)
+ case tunnelUrl := <-basicClientUrlCh2:
+ tunnelUrls = append(tunnelUrls, tunnelUrl)
+ }
+ }
// Initialize http client
client := httpClient()
@@ -783,18 +845,27 @@ func TestSimulation(t *testing.T) {
verifyContentLengthWithNoBodyRequestHandled,
}
- for _, simTest := range simulationTests {
- wg.Add(1)
- go simTest(t, client, tunnelUrl, &wg)
- }
+ // Loop through all tunnel urls and run simulation tests
+ for _, tunnelUrl := range tunnelUrls {
+
+ for _, simTest := range simulationTests {
+ wg.Add(1)
+ go simTest(t, client, tunnelUrl, &wg)
+ }
- for _, manualClientSimTest := range manualClientSimulationTests {
- wg.Add(1)
- go manualClientSimTest(t, tunnelUrl, &wg)
+ for _, manualClientSimTest := range manualClientSimulationTests {
+ wg.Add(1)
+ go manualClientSimTest(t, tunnelUrl, &wg)
+ }
}
wg.Wait()
+ // Delete cert file
+ if rmErr := os.Remove("./temp-cert"); rmErr != nil {
+ log.Fatal(rmErr)
+ }
+
// Stop simulation tests
simulationCancel()
diff --git a/simulations/simulation_utils.go b/simulations/simulation_utils.go
index d6ae1bf..0a13fab 100644
--- a/simulations/simulation_utils.go
+++ b/simulations/simulation_utils.go
@@ -95,7 +95,8 @@ func httpClient() *http.Client {
dialer := initCustomDialer()
tp := &http.Transport{
- DialContext: dialer.DialContext,
+ DialContext: dialer.DialContext,
+ DisableKeepAlives: true,
}
client := &http.Client{Transport: tp}
return client