Compare commits

..

No commits in common. "master" and "v2.1.1" have entirely different histories.

97 changed files with 3728 additions and 7821 deletions

View File

@ -1,91 +0,0 @@
name: Build and test
on: [ push ]
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '^1.24' # The Go version to download (if necessary) and use.
- run: go test -race -coverprofile coverage.txt -coverpkg ./... -covermode atomic ./...
- uses: codecov/codecov-action@v4
with:
files: coverage.txt
token: ${{ secrets.CODECOV_TOKEN }}
compat-test:
runs-on: ubuntu-latest
strategy:
matrix:
encryption-method: [ plain, chacha20-poly1305 ]
num-conn: [ 0, 1, 4 ]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '^1.24'
- name: Build Cloak
run: make
- name: Create configs
run: |
mkdir config
cat << EOF > config/ckclient.json
{
"Transport": "direct",
"ProxyMethod": "iperf",
"EncryptionMethod": "${{ matrix.encryption-method }}",
"UID": "Q4GAXHVgnDLXsdTpw6bmoQ==",
"PublicKey": "4dae/bF43FKGq+QbCc5P/E/MPM5qQeGIArjmJEHiZxc=",
"ServerName": "cloudflare.com",
"BrowserSig": "firefox",
"NumConn": ${{ matrix.num-conn }}
}
EOF
cat << EOF > config/ckserver.json
{
"ProxyBook": {
"iperf": [
"tcp",
"127.0.0.1:5201"
]
},
"BindAddr": [
":8443"
],
"BypassUID": [
"Q4GAXHVgnDLXsdTpw6bmoQ=="
],
"RedirAddr": "cloudflare.com",
"PrivateKey": "AAaskZJRPIAbiuaRLHsvZPvE6gzOeSjg+ZRg1ENau0Y="
}
EOF
- name: Start iperf3 server
run: docker run -d --name iperf-server --network host ajoergensen/iperf3:latest --server
- name: Test new client against old server
run: |
docker run -d --name old-cloak-server --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-server -c config/ckserver.json --verbosity debug
build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug | tee new-cloak-client.log &
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
docker stop old-cloak-server
- name: Test old client against new server
run: |
build/ck-server -c config/ckserver.json --verbosity debug | tee new-cloak-server.log &
docker run -d --name old-cloak-client --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
docker stop old-cloak-client
- name: Dump docker logs
if: always()
run: |
docker container logs iperf-server > iperf-server.log
docker container logs old-cloak-server > old-cloak-server.log
docker container logs old-cloak-client > old-cloak-client.log
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.encryption-method }}-${{ matrix.num-conn }}-conn-logs
path: ./*.log

View File

@ -1,50 +0,0 @@
on:
push:
tags:
- 'v*'
name: Create Release
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: |
export PATH=${PATH}:`go env GOPATH`/bin
v=${GITHUB_REF#refs/*/} ./release.sh
- name: Release
uses: softprops/action-gh-release@v1
with:
files: release/*
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
build-docker:
runs-on: ubuntu-latest
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: |
cbeuw/cloak
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
with:
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

6
.gitignore vendored
View File

@ -1,6 +0,0 @@
corpus/
suppressions/
crashers/
*.zip
.idea/
build/

View File

@ -1,5 +0,0 @@
FROM golang:latest
RUN git clone https://github.com/cbeuw/Cloak.git
WORKDIR Cloak
RUN make

236
README.md
View File

@ -1,239 +1,109 @@
[![Build Status](https://github.com/cbeuw/Cloak/workflows/Build%20and%20test/badge.svg)](https://github.com/cbeuw/Cloak/actions) ![image](https://user-images.githubusercontent.com/7034308/65361318-0a719180-dbfb-11e9-96de-56d1023856f0.png)
[![codecov](https://codecov.io/gh/cbeuw/Cloak/branch/master/graph/badge.svg)](https://codecov.io/gh/cbeuw/Cloak)
[![Go Report Card](https://goreportcard.com/badge/github.com/cbeuw/Cloak)](https://goreportcard.com/report/github.com/cbeuw/Cloak)
[![Donate](https://img.shields.io/badge/Donate-PayPal-green.svg)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)
<p align="center"> ![Cloak](https://user-images.githubusercontent.com/7034308/64479678-1e0c0980-d1b2-11e9-836e-b4c1238f2669.png)
<img src="https://user-images.githubusercontent.com/7034308/96387206-3e214100-1198-11eb-8917-689d7c56e0cd.png" />
<img src="https://user-images.githubusercontent.com/7034308/155593583-f22bcfe2-ac22-4afb-9288-1a0e8a791a0d.svg" />
</p>
<p align="center"> Cloak is a universal pluggable transport that cryptographically obfuscates proxy traffic as legitimate HTTPS traffic, disguises the proxy server as a normal web server, multiplexes traffic through multiple TCP connections and provides multi-user usage control.
<img src="https://user-images.githubusercontent.com/7034308/155629720-54dd8758-ec98-4fed-b603-623f0ad83b6c.svg" />
</p>
Cloak is a [pluggable transport](https://datatracker.ietf.org/meeting/103/materials/slides-103-pearg-pt-slides-01) that enhances Cloak works fundamentally by masquerading proxy traffic as indistinguishable normal web browsing traffic. This increases the collateral damage to censorship actions and therefore make it very difficult, if not impossible, for censors to selectively block censorship evasion tools and proxy servers without affecting services that the state may also heavily rely on.
traditional proxy tools like OpenVPN to evade [sophisticated censorship](https://en.wikipedia.org/wiki/Deep_packet_inspection) and [data discrimination](https://en.wikipedia.org/wiki/Net_bias).
Cloak is not a standalone proxy program. Rather, it works by masquerading proxied traffic as normal web browsing Cloak eliminates any "fingerprints" exposed by traditional proxy protocol designs which can be identified by adversaries through deep packet inspection. If a non-Cloak program or an unauthorised Cloak user (such as an adversary's prober) attempts to connect to Cloak server, it will serve as a transparent proxy between said machine and an ordinary website, so that to any unauthorised third party, a host running Cloak server is indistinguishable from an innocent web server. This is achieved through the use a series of [cryptographic stegnatography techniques](https://github.com/cbeuw/Cloak/wiki/Steganography-and-encryption).
activities. In contrast to traditional tools which have very prominent traffic fingerprints and can be blocked by simple filtering rules,
it's very difficult to precisely target Cloak with little false positives. This increases the collateral damage to censorship actions as
attempts to block Cloak could also damage services the censor state relies on.
To any third party observer, a host running Cloak server is indistinguishable from an innocent web server. Both while Since Cloak is transparent, it can be used in conjunction with any proxy software that tunnels traffic through TCP, such as Shadowsocks, OpenVPN and Tor. Multiple proxy servers can be running on the same server host machine and Cloak server will act as a reverse proxy, bridging clients with their desired proxy end.
passively observing traffic flow to and from the server, as well as while actively probing the behaviours of a Cloak
server. This is achieved through the use a series
of [cryptographic steganography techniques](https://github.com/cbeuw/Cloak/wiki/Steganography-and-encryption).
Cloak can be used in conjunction with any proxy program that tunnels traffic through TCP or Cloak multiplexes traffic through multiple underlying TCP connections which reduces head-of-line blocking and eliminates TCP handshake overhead.
UDP, such as Shadowsocks, OpenVPN and Tor. Multiple proxy servers can be running on the same server host and
Cloak server will act as a reverse proxy, bridging clients with their desired proxy end.
Cloak multiplexes traffic through multiple underlying TCP connections which reduces head-of-line blocking and eliminates Cloak provides multi-user support, allowing multiple clients to connect to the proxy server on the same port (443 by default). It also provides QoS controls for individual users such as data usage limit and bandwidth control.
TCP handshake overhead. This also makes the traffic pattern more similar to real websites.
Cloak provides multi-user support, allowing multiple clients to connect to the proxy server on the same port (443 by Cloak has two modes of [_Transport_](https://github.com/cbeuw/Cloak/wiki/CDN-mode): `direct` and `CDN`. Clients can either connect to the host running Cloak server directly, or it can instead connect to a CDN edge server, which may be used by many legitimate websites as well, and thus increase the collateral damage to censorship.
default). It also provides traffic management features such as usage credit and bandwidth control. This allows a proxy
server to serve multiple users even if the underlying proxy software wasn't designed for multiple users
Cloak also supports tunneling through an intermediary CDN server such as Amazon Cloudfront. Such services are so widely used, **Cloak 2.x is not compatible with legacy Cloak 1.x's protocol, configuration file or database file. Cloak 1.x protocol has critical cryptographic flaws regarding encrypting stream headers. Using Cloak 1.x is strongly discouraged**
attempts to disrupt traffic to them can lead to very high collateral damage for the censor.
This project is based on [GoQuiet](https://github.com/cbeuw/GoQuiet). Through multiplexing, Cloak provides a siginifcant reduction in webpage loading time compared to GoQuiet (from 10% to 50%+, depending on the amount of content on the webpage, see [benchmarks](https://github.com/cbeuw/Cloak/wiki/Web-page-loading-benchmarks)).
## Quick Start ## Quick Start
To quickly deploy Cloak with Shadowsocks on a server, you can run this [script](https://github.com/HirbodBehnam/Shadowsocks-Cloak-Installer/blob/master/Cloak2-Installer.sh) written by @HirbodBehnam
To quickly deploy Cloak with Shadowsocks on a server, you can run
this [script](https://github.com/HirbodBehnam/Shadowsocks-Cloak-Installer/blob/master/Cloak2-Installer.sh) written by
@HirbodBehnam
Table of Contents
=================
* [Quick Start](#quick-start)
* [Build](#build)
* [Configuration](#configuration)
* [Server](#server)
* [Client](#client)
* [Setup](#setup)
* [Server](#server-1)
* [To add users](#to-add-users)
* [Unrestricted users](#unrestricted-users)
* [Users subject to bandwidth and credit controls](#users-subject-to-bandwidth-and-credit-controls)
* [Client](#client-1)
* [Support me](#support-me)
## Build ## Build
If you are not using the experimental go mod support, make sure you `go get` the following dependencies:
```bash
git clone https://github.com/cbeuw/Cloak
cd Cloak
go get ./...
make
``` ```
github.com/boltdb/bolt
Built binaries will be in `build` folder. github.com/juju/ratelimit
github.com/gorilla/mux
github.com/gorilla/websocket
github.com/sirupsen/logrus
golang.org/x/crypto
github.com/refraction-networking/utls
```
Then run `make client` or `make server`. Output binary will be in `build` folder.
## Configuration ## Configuration
Examples of configuration files can be found under `example_config` folder.
### Server ### Server
`RedirAddr` is the redirection address when the incoming traffic is not from a Cloak client. It should either be the same as, or correspond to the IP record of the `ServerName` field set in `ckclient.json`.
`RedirAddr` is the redirection address when the incoming traffic is not from a Cloak client. Ideally it should be set to `BindAddr` is a list of addresses Cloak will bind and listen to (e.g. [":443",":80] to listen to port 443 and 80 on all interfaces)
a major website allowed by the censor (e.g. `www.bing.com`)
`BindAddr` is a list of addresses Cloak will bind and listen to (e.g. `[":443",":80"]` to listen to port 443 and 80 on `ProxyBook` is a nested JSON section which defines the address of different proxy server ends. For instance, if OpenVPN server is listening on 127.0.0.1:1194, the pair should be `"openvpn":"127.0.0.1:1194"`. There can be multiple pairs. You can add any other proxy server in a similar fashion, as long as the name matches the `ProxyMethod` in the client config exactly (case-sensitive).
all interfaces)
`ProxyBook` is an object whose key is the name of the ProxyMethod used on the client-side (case-sensitive). Its value is
an array whose first element is the protocol, and the second element is an `IP:PORT` string of the upstream proxy server
that Cloak will forward the traffic to.
Example:
```json
{
"ProxyBook": {
"shadowsocks": [
"tcp",
"localhost:51443"
],
"openvpn": [
"tcp",
"localhost:12345"
]
}
}
```
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64. `PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
`AdminUID` is the UID of the admin user in base64.
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions `BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`. `DatabasePath` is the path to userinfo.db. If userinfo.db doesn't exist in this directory, Cloak will create one automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as / (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so and will raise an error. See Issue #13.**
`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will
create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`.
This field also has no effect if `AdminUID` isn't a valid UID or is empty.
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
### Client ### Client
`UID` is your UID in base64. `UID` is your UID in base64.
`Transport` can be either `direct` or `CDN`. If the server host wishes you to connect to it directly, use `direct`. If `Transport` can be either `direct` or `CDN`. If the server host wishes you to connect to it directly, use `direct`. If instead a CDN is used, use `CDN`.
instead a CDN is used, use `CDN`.
`PublicKey` is the static curve25519 public key in base64, given by the server admin. `PublicKey` is the static curve25519 public key, given by the server admin.
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the `ProxyMethod` is the name of the proxy method you are using.
server's `ProxyBook` exactly.
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` ( `EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Note: Cloak isn't intended to provide data encryption. The point of encryption is to hide fingerprints of proxy protocols. If the proxy protocol is already fingerprint-less, which is the case for Shadowsocks, this field can be left as `plain`. Options are `plain`, `aes-gcm` and `chacha20-poly1305`.
synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport
security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically
random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH
encryption and authentication (via AEAD or similar techniques).**
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should `ServerName` is the domain you want to make your ISP or firewall think you are visiting.
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. Use `random` to randomize the server name for every connection made.
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new `NumConn` is the amount of underlying TCP connections you want to use. The default of 4 should be appropriate for most people. Setting it too high will hinder the performance.
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
the alternative domains.
Example: `BrowserSig` is the browser you want to **appear** to be using. It's not relevant to the browser you are actually using. Currently, `chrome` and `firefox` are supported.
```json
{
"ServerName": "bing.com",
"AlternativeNames": ["cloudflare.com", "github.com"]
}
```
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with
the CDN server, this domain name will be used in the `Host` header of the HTTP request to ask the CDN server to
establish a WebSocket connection with this host.
`CDNWsUrlPath` is the url path used to build websocket request sent under `CDN` mode, and also only has effect
when `Transport` is set to `CDN`. If unset, it will default to "/". This option is used to build the first line of the
HTTP request after a TLS session is extablished. It's mainly for a Cloak server behind a reverse proxy, while only
requests under specific url path are forwarded.
`NumConn` is the amount of underlying TCP connections you want to use. The default of 4 should be appropriate for most
people. Setting it too high will hinder the performance. Setting it to 0 will disable connection multiplexing and each
TCP connection will spawn a separate short-lived session that will be closed after it is terminated. This makes it
behave like GoQuiet. This maybe useful for people with unstable connections.
`BrowserSig` is the browser you want to **appear** to be using. It's not relevant to the browser you are actually using.
Currently, `chrome`, `firefox` and `safari` are supported.
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
Cloak server. Zero or negative value disables it. Default is 0 (disabled). Warning: Enabling it might make your server
more detectable as a proxy, but it will make the Cloak client detect internet interruption more quickly.
`StreamTimeout` is the number of seconds of Cloak waits for an incoming connection from a proxy program to send any
data, after which the connection will be closed by Cloak. Cloak will not enforce any timeout on TCP connections after it
is established.
## Setup ## Setup
### For the administrator of the server
### Server 0. Set up the underlying proxy server. Note that if you are using OpenVPN, you must change the protocol to TCP as Cloak does not support UDP
0. Install at least one underlying proxy server (e.g. OpenVPN, Shadowsocks).
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo. 1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
2. Run `ck-server -key`. The **public** should be given to users, the **private** key should be kept secret. 2. Run ck-server -k. The base64 string before the comma is the **public** key to be given to users, the one after the comma is the **private** key to be kept secret
3. (Skip if you only want to add unrestricted users) Run `ck-server -uid`. The new UID will be used as `AdminUID`. 3. Run `ck-server -u`. This will be used as the AdminUID
4. Copy example_config/ckserver.json into a desired location. Change `PrivateKey` to the private key you just obtained; 4. Copy example_config/ckserver.json into a desired location. Change `PrivateKey` to the private key you just obtained; change `AdminUID` to the UID you just obtained.
change `AdminUID` to the UID you just obtained. 5. Configure your underlying proxy server so that they all listen on localhost. Edit `ProxyBook` in the configuration file accordingly
5. Configure your underlying proxy server so that they all listen on localhost. Edit `ProxyBook` in the configuration 6. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides) Run `sudo ck-server -c <path to ckserver.json>`. ck-server needs root privilege because it binds to a low numbered port (443). Alternatively you can follow https://superuser.com/a/892391 to avoid granting ck-server root privilege unnecessarily.
file accordingly
6. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
Run `sudo ck-server -c <path to ckserver.json>`. ck-server needs root privilege because it binds to a low numbered
port (443). Alternatively you can follow https://superuser.com/a/892391 to avoid granting ck-server root privilege
unnecessarily.
#### To add users #### To add users
##### Unrestricted users ##### Unrestricted users
Run `ck-server -u` and add the UID into the `BypassUID` field in `ckserver.json`
Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.json`
##### Users subject to bandwidth and credit controls ##### Users subject to bandwidth and credit controls
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to enter admin mode
0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db` 2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a static site, there is no backend and all data entered into this site are processed between your browser and the Cloak API endpoint you specified. Alternatively you can download the repo at https://github.com/cbeuw/Cloak-panel and host it on your own web server).
in `DatabasePath` (Cloak will create this file for you if it didn't already exist). 3. Type in 127.0.0.1:<the port you entered in step 1> as the API Base, and click `List`.
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
enter admin mode
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data
entered into this site are processed between your browser and the Cloak API endpoint you specified. Alternatively you
can download the repo at https://github.com/cbeuw/Cloak-panel and open `index.html` in a browser. No web server is
required).
3. Type in `127.0.0.1:<the port you entered in step 1>` as the API Base, and click `List`.
4. You can add in more users by clicking the `+` panel 4. You can add in more users by clicking the `+` panel
Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start ck-server.
ck-server.
### Client
### Instructions for clients
**Android client is available here: https://github.com/cbeuw/Cloak-android** **Android client is available here: https://github.com/cbeuw/Cloak-android**
0. Install the underlying proxy client corresponding to what the server has. 0. Install and configure the proxy client based on the server
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo. 1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
2. Obtain the public key and your UID from the administrator of your server 2. Obtain the public key and your UID from the administrator of your server
3. Copy `example_config/ckclient.json` into a location of your choice. Enter the `UID` and `PublicKey` you have 3. Copy example_config/ckclient.json into a location of your choice. Enter the `UID` and `PublicKey` you have obtained. Set `ProxyMethod` to match exactly the corresponding entry in `ProxyBook` on the server end
obtained. Set `ProxyMethod` to match exactly the corresponding entry in `ProxyBook` on the server end 4. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides) Run `ck-client -c <path to ckclient.json> -s <ip of your server>`
4. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
Run `ck-client -c <path to ckclient.json> -s <ip of your server>`
## Support me ## Support me
If you find this project useful, you can visit my [merch store](https://teespring.com/en-GB/stores/andys-scribble) which sells some of my designed t-shirts, phone cases, mugs and other bits and bobs; alternatively you can donate directly to me
If you find this project useful, you can visit my [merch store](https://www.redbubble.com/people/cbeuw/explore);
alternatively you can donate directly to me
[![Donate](https://img.shields.io/badge/Donate-PayPal-green.svg)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url) [![Donate](https://img.shields.io/badge/Donate-PayPal-green.svg)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)

View File

@ -1,25 +1,229 @@
//go:build go1.11
// +build go1.11 // +build go1.11
package main package main
import ( import (
"crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"flag" "flag"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync"
"github.com/cbeuw/Cloak/internal/common" "sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/client" "github.com/cbeuw/Cloak/internal/client"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var version string var version string
func makeSession(sta *client.State, isAdmin bool) *mux.Session {
log.Info("Attemtping to start a new session")
if !isAdmin {
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.
quad := make([]byte, 4)
rand.Read(quad)
atomic.StoreUint32(&sta.SessionID, binary.BigEndian.Uint32(quad))
}
d := net.Dialer{Control: protector}
connsCh := make(chan net.Conn, sta.NumConn)
var _sessionKey atomic.Value
var wg sync.WaitGroup
for i := 0; i < sta.NumConn; i++ {
wg.Add(1)
go func() {
makeconn:
connectingIP := sta.RemoteHost
if net.ParseIP(connectingIP).To4() == nil {
// IPv6 needs square brackets
connectingIP = "[" + connectingIP + "]"
}
remoteConn, err := d.Dial("tcp", connectingIP+":"+sta.RemotePort)
if err != nil {
log.Errorf("Failed to establish new connections to remote: %v", err)
// TODO increase the interval if failed multiple times
time.Sleep(time.Second * 3)
goto makeconn
}
var sk []byte
remoteConn, sk, err = sta.Transport.PrepareConnection(sta, remoteConn)
if err != nil {
remoteConn.Close()
log.Errorf("Failed to prepare connection to remote: %v", err)
time.Sleep(time.Second * 3)
goto makeconn
}
_sessionKey.Store(sk)
connsCh <- remoteConn
wg.Done()
}()
}
wg.Wait()
log.Debug("All underlying connections established")
sessionKey := _sessionKey.Load().([]byte)
obfuscator, err := mux.GenerateObfs(sta.EncryptionMethod, sessionKey, sta.Transport.HasRecordLayer())
if err != nil {
log.Fatal(err)
}
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: sta.Transport.UnitReadFunc(),
Unordered: sta.Unordered,
}
sesh := mux.MakeSession(sta.SessionID, seshConfig)
for i := 0; i < sta.NumConn; i++ {
conn := <-connsCh
sesh.AddConnection(conn)
}
log.Infof("Session %v established", sta.SessionID)
return sesh
}
func routeUDP(sta *client.State, adminUID []byte) {
var sesh *mux.Session
localUDPAddr, err := net.ResolveUDPAddr("udp", sta.LocalHost+":"+sta.LocalPort)
if err != nil {
log.Fatal(err)
}
start:
localConn, err := net.ListenUDP("udp", localUDPAddr)
if err != nil {
log.Fatal(err)
}
var otherEnd atomic.Value
data := make([]byte, 10240)
i, oe, err := localConn.ReadFromUDP(data)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
otherEnd.Store(oe)
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil)
}
log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String())
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
//localConnWrite.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
//localConnWrite.Close()
stream.Close()
return
}
// stream to proxy
go func() {
buf := make([]byte, 16380)
for {
i, err := io.ReadAtLeast(stream, buf, 1)
if err != nil {
log.Print(err)
localConn.Close()
stream.Close()
break
}
i, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr))
if err != nil {
log.Print(err)
localConn.Close()
stream.Close()
break
}
}
}()
// proxy to stream
buf := make([]byte, 16380)
if sta.Timeout != 0 {
localConn.SetReadDeadline(time.Now().Add(sta.Timeout))
}
for {
if sta.Timeout != 0 {
localConn.SetReadDeadline(time.Now().Add(sta.Timeout))
}
i, oe, err := localConn.ReadFromUDP(buf)
if err != nil {
localConn.Close()
stream.Close()
break
}
otherEnd.Store(oe)
i, err = stream.Write(buf[:i])
if err != nil {
localConn.Close()
stream.Close()
break
}
}
goto start
}
func routeTCP(sta *client.State, adminUID []byte) {
tcpListener, err := net.Listen("tcp", sta.LocalHost+":"+sta.LocalPort)
if err != nil {
log.Fatal(err)
}
var sesh *mux.Session
for {
localConn, err := tcpListener.Accept()
if err != nil {
log.Fatal(err)
continue
}
if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil)
}
go func() {
data := make([]byte, 10240)
i, err := io.ReadAtLeast(localConn, data, 1)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
stream.Close()
return
}
go util.Pipe(localConn, stream, 0)
util.Pipe(stream, localConn, sta.Timeout)
}()
}
}
func main() { func main() {
// Should be 127.0.0.1 to listen to a proxy client on this machine // Should be 127.0.0.1 to listen to a proxy client on this machine
var localHost string var localHost string
@ -33,33 +237,28 @@ func main() {
var udp bool var udp bool
var config string var config string
var b64AdminUID string var b64AdminUID string
var vpnMode bool
var tcpFastOpen bool
log_init() log_init()
log.SetLevel(log.DebugLevel)
ssPluginMode := os.Getenv("SS_LOCAL_HOST") != "" if os.Getenv("SS_LOCAL_HOST") != "" {
localHost = os.Getenv("SS_LOCAL_HOST")
verbosity := flag.String("verbosity", "info", "verbosity level") localPort = os.Getenv("SS_LOCAL_PORT")
if ssPluginMode { remoteHost = os.Getenv("SS_REMOTE_HOST")
remotePort = os.Getenv("SS_REMOTE_PORT")
config = os.Getenv("SS_PLUGIN_OPTIONS") config = os.Getenv("SS_PLUGIN_OPTIONS")
flag.BoolVar(&vpnMode, "V", false, "ignored.")
flag.BoolVar(&tcpFastOpen, "fast-open", false, "ignored.")
flag.Parse() // for verbosity only
} else { } else {
flag.StringVar(&localHost, "i", "127.0.0.1", "localHost: Cloak listens to proxy clients on this ip") flag.StringVar(&localHost, "i", "127.0.0.1", "localHost: Cloak listens to proxy clients on this ip")
flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port") flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port")
flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server") flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server")
flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443") flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443")
flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol") flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol")
flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options separated with semicolons") flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options seperated with semicolons")
flag.StringVar(&proxyMethod, "proxy", "", "proxy: the proxy method's name. It must match exactly with the corresponding entry in server's ProxyBook") flag.StringVar(&proxyMethod, "proxy", "", "proxy: the proxy method's name. It must match exactly with the corresponding entry in server's ProxyBook")
flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api") flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api")
askVersion := flag.Bool("v", false, "Print the version number") askVersion := flag.Bool("v", false, "Print the version number")
printUsage := flag.Bool("h", false, "Print this message") printUsage := flag.Bool("h", false, "Print this message")
verbosity := flag.String("verbosity", "info", "verbosity level")
// commandline arguments overrides json
flag.Parse() flag.Parse()
if *askVersion { if *askVersion {
@ -72,76 +271,47 @@ func main() {
return return
} }
log.Info("Starting standalone mode")
}
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
lvl, err := log.ParseLevel(*verbosity) lvl, err := log.ParseLevel(*verbosity)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.SetLevel(lvl) log.SetLevel(lvl)
rawConfig, err := client.ParseConfig(config) log.Info("Starting standalone mode")
}
sta := &client.State{
LocalHost: localHost,
LocalPort: localPort,
RemoteHost: remoteHost,
RemotePort: remotePort,
Now: time.Now,
}
err := sta.ParseConfig(config)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if ssPluginMode { if proxyMethod != "" {
if rawConfig.ProxyMethod == "" { sta.ProxyMethod = proxyMethod
rawConfig.ProxyMethod = "shadowsocks"
}
// json takes precedence over environment variables
// i.e. if json field isn't empty, use that
if rawConfig.RemoteHost == "" {
rawConfig.RemoteHost = os.Getenv("SS_REMOTE_HOST")
}
if rawConfig.RemotePort == "" {
rawConfig.RemotePort = os.Getenv("SS_REMOTE_PORT")
}
if rawConfig.LocalHost == "" {
rawConfig.LocalHost = os.Getenv("SS_LOCAL_HOST")
}
if rawConfig.LocalPort == "" {
rawConfig.LocalPort = os.Getenv("SS_LOCAL_PORT")
}
} else {
// commandline argument takes precedence over json
// if commandline argument is set, use commandline
flag.Visit(func(f *flag.Flag) {
// manually set ones
switch f.Name {
case "i":
rawConfig.LocalHost = localHost
case "l":
rawConfig.LocalPort = localPort
case "s":
rawConfig.RemoteHost = remoteHost
case "p":
rawConfig.RemotePort = remotePort
case "u":
rawConfig.UDP = udp
case "proxy":
rawConfig.ProxyMethod = proxyMethod
}
})
// ones with default values
if rawConfig.LocalHost == "" {
rawConfig.LocalHost = localHost
}
if rawConfig.LocalPort == "" {
rawConfig.LocalPort = localPort
}
if rawConfig.RemotePort == "" {
rawConfig.RemotePort = remotePort
}
} }
localConfig, remoteConfig, authInfo, err := rawConfig.ProcessRawConfig(common.RealWorldState) if os.Getenv("SS_LOCAL_HOST") != "" {
if err != nil { sta.ProxyMethod = "shadowsocks"
log.Fatal(err) }
if sta.LocalPort == "" {
log.Fatal("Must specify localPort")
}
if sta.RemoteHost == "" {
log.Fatal("Must specify remoteHost")
}
listeningIP := sta.LocalHost
if net.ParseIP(listeningIP).To4() == nil {
// IPv6 needs square brackets
listeningIP = "[" + listeningIP + "]"
} }
var adminUID []byte var adminUID []byte
@ -152,55 +322,26 @@ func main() {
} }
} }
var seshMaker func() *mux.Session
d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive}
if adminUID != nil { if adminUID != nil {
log.Infof("API base is %v", localConfig.LocalAddr) log.Infof("API base is %v:%v", listeningIP, sta.LocalPort)
authInfo.UID = adminUID sta.SessionID = 0
authInfo.SessionId = 0 sta.UID = adminUID
remoteConfig.NumConn = 1 sta.NumConn = 1
seshMaker = func() *mux.Session {
return client.MakeSession(remoteConfig, authInfo, d)
}
} else { } else {
var network string var network string
if authInfo.Unordered { if udp {
network = "UDP" network = "UDP"
sta.Unordered = true
} else { } else {
network = "TCP" network = "TCP"
sta.Unordered = false
} }
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) log.Infof("Listening on %v %v:%v for %v client", network, listeningIP, sta.LocalPort, sta.ProxyMethod)
seshMaker = func() *mux.Session {
authInfo := authInfo // copy the struct because we are overwriting SessionId
randByte := make([]byte, 1)
common.RandRead(authInfo.WorldState.Rand, randByte)
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.
quad := make([]byte, 4)
common.RandRead(authInfo.WorldState.Rand, quad)
authInfo.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(remoteConfig, authInfo, d)
}
} }
if authInfo.Unordered { if udp {
acceptor := func() (*net.UDPConn, error) { routeUDP(sta, adminUID)
udpAddr, _ := net.ResolveUDPAddr("udp", localConfig.LocalAddr)
return net.ListenUDP("udp", udpAddr)
}
client.RouteUDP(acceptor, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
} else { } else {
listener, err := net.Listen("tcp", localConfig.LocalAddr) routeTCP(sta, adminUID)
if err != nil {
log.Fatal(err)
}
client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
} }
} }

View File

@ -1,4 +1,3 @@
//go:build !android
// +build !android // +build !android
package main package main

View File

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build android
// +build android // +build android
package main package main
@ -29,10 +28,9 @@ import "C"
import ( import (
"bufio" "bufio"
log "github.com/sirupsen/logrus"
"os" "os"
"unsafe" "unsafe"
log "github.com/sirupsen/logrus"
) )
var ( var (

View File

@ -1,4 +1,3 @@
//go:build !android
// +build !android // +build !android
package main package main

View File

@ -1,6 +1,4 @@
//go:build android
// +build android // +build android
package main package main
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go // Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go
@ -21,8 +19,9 @@ package main
int fd[n]; \ int fd[n]; \
} }
int ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds, int
void *buffer) { ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds, void *buffer)
{
struct msghdr msghdr; struct msghdr msghdr;
char nothing = '!'; char nothing = '!';
struct iovec nothing_ptr; struct iovec nothing_ptr;
@ -42,33 +41,34 @@ int ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds,
cmsg->cmsg_len = msghdr.msg_controllen; cmsg->cmsg_len = msghdr.msg_controllen;
cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_type = SCM_RIGHTS;
for (i = 0; i < n_fds; i++) for(i = 0; i < n_fds; i++)
((int *)CMSG_DATA(cmsg))[i] = fds[i]; ((int *)CMSG_DATA(cmsg))[i] = fds[i];
return (sendmsg(sock, &msghdr, 0) >= 0 ? 0 : -1); return(sendmsg(sock, &msghdr, 0) >= 0 ? 0 : -1);
} }
int ancil_send_fd(int sock, int fd) { int
ancil_send_fd(int sock, int fd)
{
ANCIL_FD_BUFFER(1) buffer; ANCIL_FD_BUFFER(1) buffer;
return (ancil_send_fds_with_buffer(sock, &fd, 1, &buffer)); return(ancil_send_fds_with_buffer(sock, &fd, 1, &buffer));
} }
void set_timeout(int sock) { void
set_timeout(int sock)
{
struct timeval tv; struct timeval tv;
tv.tv_sec = 3; tv.tv_sec = 3;
tv.tv_usec = 0; tv.tv_usec = 0;
setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval));
sizeof(struct timeval)); setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(struct timeval));
setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, }
sizeof(struct timeval));
}
*/ */
import "C" import "C"
import ( import (
"syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"syscall"
) )
// In Android, once an app starts the VpnService, all outgoing traffic are routed by the system // In Android, once an app starts the VpnService, all outgoing traffic are routed by the system

View File

@ -1,95 +1,212 @@
package main package main
import ( import (
"bytes"
"crypto/rand"
"encoding/base64"
"flag" "flag"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"time"
"github.com/cbeuw/Cloak/internal/common" mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server" "github.com/cbeuw/Cloak/internal/server"
"github.com/cbeuw/Cloak/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var b64 = base64.StdEncoding.EncodeToString
var version string var version string
func resolveBindAddr(bindAddrs []string) ([]net.Addr, error) { func dispatchConnection(conn net.Conn, sta *server.State) {
var addrs []net.Addr remoteAddr := conn.RemoteAddr()
for _, addr := range bindAddrs { var err error
bindAddr, err := net.ResolveTCPAddr("tcp", addr) buf := make([]byte, 1500)
if err != nil {
return nil, err
}
addrs = append(addrs, bindAddr)
}
return addrs, nil
}
// parse what shadowsocks server wants us to bind and harmonise it with what's already in bindAddr from // TODO: potential fingerprint for active probers here
// our own config's BindAddr. This prevents duplicate bindings etc. conn.SetReadDeadline(time.Now().Add(3 * time.Second))
func parseSSBindAddr(ssRemoteHost string, ssRemotePort string, ckBindAddr *[]net.Addr) error { i, err := io.ReadAtLeast(conn, buf, 1)
var ssBind string if err != nil {
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0 go conn.Close()
v4nv6 := len(strings.Split(ssRemoteHost, "|")) == 2 return
if v4nv6 { }
ssBind = ":" + ssRemotePort conn.SetReadDeadline(time.Time{})
data := buf[:i]
goWeb := func() {
_, remotePort, _ := net.SplitHostPort(conn.LocalAddr().String())
webConn, err := net.Dial("tcp", net.JoinHostPort(sta.RedirAddr.String(), remotePort))
if err != nil {
log.Errorf("Making connection to redirection server: %v", err)
return
}
_, err = webConn.Write(data)
if err != nil {
log.Error("Failed to send first packet to redirection server", err)
}
go util.Pipe(webConn, conn, 0)
go util.Pipe(conn, webConn, 0)
}
ci, finishHandshake, err := server.PrepareConnection(data, sta, conn)
if err != nil {
log.WithFields(log.Fields{
"remoteAddr": remoteAddr,
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Warn(err)
goWeb()
return
}
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey, ci.Transport.HasRecordLayer())
if err != nil {
log.Error(err)
goWeb()
return
}
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
preparedConn, err := finishHandshake(sessionKey)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: ci.Transport.UnitReadFunc(),
}
sesh := mux.MakeSession(0, seshConfig)
sesh.AddConnection(preparedConn)
//TODO: Router could be nil in cnc mode
log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session")
err = http.Serve(sesh, sta.LocalAPIRouter)
if err != nil {
log.Error(err)
return
}
}
var user *server.ActiveUser
if sta.IsBypass(ci.UID) {
user, err = sta.Panel.GetBypassUser(ci.UID)
} else { } else {
ssBind = net.JoinHostPort(ssRemoteHost, ssRemotePort) user, err = sta.Panel.GetUser(ci.UID)
} }
ssBindAddr, err := net.ResolveTCPAddr("tcp", ssBind)
if err != nil { if err != nil {
return fmt.Errorf("unable to resolve bind address provided by SS: %v", err) log.WithFields(log.Fields{
"UID": b64(ci.UID),
"remoteAddr": remoteAddr,
"error": err,
}).Warn("+1 unauthorised UID")
goWeb()
return
} }
shouldAppend := true seshConfig := &mux.SessionConfig{
for i, addr := range *ckBindAddr { Obfuscator: obfuscator,
if addr.String() == ssBindAddr.String() { Valve: nil,
shouldAppend = false UnitRead: ci.Transport.UnitReadFunc(),
Unordered: ci.Unordered,
} }
if addr.String() == ":"+ssRemotePort { // already listening on all interfaces sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
shouldAppend = false if err != nil {
user.CloseSession(ci.SessionId, "")
log.Error(err)
return
} }
if addr.String() == "0.0.0.0:"+ssRemotePort || addr.String() == "[::]:"+ssRemotePort {
// if config listens on one ip version but ss wants to listen on both, if existing {
// listen on both preparedConn, err := finishHandshake(sesh.SessionKey)
if ssBindAddr.String() == ":"+ssRemotePort { if err != nil {
shouldAppend = true log.Error(err)
(*ckBindAddr)[i] = ssBindAddr return
}
log.Trace("finished handshake")
sesh.AddConnection(preparedConn)
return
}
preparedConn, err := finishHandshake(sessionKey)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
}).Info("New session")
sesh.AddConnection(preparedConn)
for {
newStream, err := sesh.Accept()
if err != nil {
if err == mux.ErrBrokenSession {
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
"reason": sesh.TerminalMsg(),
}).Info("Session closed")
user.CloseSession(ci.SessionId, "")
return
} else {
continue
} }
} }
proxyAddr := sta.ProxyBook[ci.ProxyMethod]
localConn, err := net.Dial(proxyAddr.Network(), proxyAddr.String())
if err != nil {
log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
user.CloseSession(ci.SessionId, "Failed to connect to proxy server")
continue
} }
if shouldAppend { log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
*ckBindAddr = append(*ckBindAddr, ssBindAddr)
go util.Pipe(localConn, newStream, 0)
go util.Pipe(newStream, localConn, sta.Timeout)
} }
return nil
} }
func main() { func main() {
// set TLS bind host through commandline for legacy support, default 0.0.0,0
var ssRemoteHost string
// set TLS bind port through commandline for legacy support, default 443
var ssRemotePort string
var config string var config string
var pluginMode bool var pluginMode bool
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
if os.Getenv("SS_LOCAL_HOST") != "" && os.Getenv("SS_LOCAL_PORT") != "" { if os.Getenv("SS_LOCAL_HOST") != "" && os.Getenv("SS_LOCAL_PORT") != "" {
pluginMode = true pluginMode = true
ssRemoteHost = os.Getenv("SS_REMOTE_HOST")
ssRemotePort = os.Getenv("SS_REMOTE_PORT")
config = os.Getenv("SS_PLUGIN_OPTIONS") config = os.Getenv("SS_PLUGIN_OPTIONS")
} else { } else {
flag.StringVar(&config, "c", "server.json", "config: path to the configuration file or its content") flag.StringVar(&config, "c", "server.json", "config: path to the configuration file or its content")
askVersion := flag.Bool("v", false, "Print the version number") askVersion := flag.Bool("v", false, "Print the version number")
printUsage := flag.Bool("h", false, "Print this message") printUsage := flag.Bool("h", false, "Print this message")
genUIDScript := flag.Bool("u", false, "Generate a UID to STDOUT") genUID := flag.Bool("u", false, "Generate a UID")
genKeyPairScript := flag.Bool("k", false, "Generate a pair of public and private key and output to STDOUT in the format of <public key>,<private key>") genKeyPair := flag.Bool("k", false, "Generate a pair of public and private key, output in the format of pubkey,pvkey")
genUIDHuman := flag.Bool("uid", false, "Generate and print out a UID")
genKeyPairHuman := flag.Bool("key", false, "Generate and print out a public-private key pair")
pprofAddr := flag.String("d", "", "debug use: ip:port to be listened by pprof profiler") pprofAddr := flag.String("d", "", "debug use: ip:port to be listened by pprof profiler")
verbosity := flag.String("verbosity", "info", "verbosity level") verbosity := flag.String("verbosity", "info", "verbosity level")
@ -104,23 +221,13 @@ func main() {
flag.Usage() flag.Usage()
return return
} }
if *genUIDScript || *genUIDHuman { if *genUID {
uid := generateUID() fmt.Println(generateUID())
if *genUIDScript {
fmt.Println(uid)
} else {
fmt.Printf("\x1B[35mYour UID is:\u001B[0m %s\n", uid)
}
return return
} }
if *genKeyPairScript || *genKeyPairHuman { if *genKeyPair {
pub, pv := generateKeyPair() pub, pv := generateKeyPair()
if *genKeyPairScript { fmt.Printf("%v,%v", pub, pv)
fmt.Printf("%v,%v\n", pub, pv)
} else {
fmt.Printf("\x1B[36mYour PUBLIC key is:\x1B[0m %65s\n", pub)
fmt.Printf("\x1B[33mYour PRIVATE key is (keep it secret):\x1B[0m %47s\n", pv)
}
return return
} }
@ -141,42 +248,63 @@ func main() {
log.Infof("Starting standalone mode") log.Infof("Starting standalone mode")
} }
sta, _ := server.InitState(time.Now)
raw, err := server.ParseConfig(config) err := sta.ParseConfig(config)
if err != nil { if err != nil {
log.Fatalf("Configuration file error: %v", err) log.Fatalf("Configuration file error: %v", err)
} }
bindAddr, err := resolveBindAddr(raw.BindAddr) if !pluginMode && len(sta.BindAddr) == 0 {
if err != nil {
log.Fatalf("unable to parse BindAddr: %v", err)
}
// in case the user hasn't specified any local address to bind to, we listen on 443 and 80
if !pluginMode && len(bindAddr) == 0 {
https, _ := net.ResolveTCPAddr("tcp", ":443") https, _ := net.ResolveTCPAddr("tcp", ":443")
http, _ := net.ResolveTCPAddr("tcp", ":80") http, _ := net.ResolveTCPAddr("tcp", ":80")
bindAddr = []net.Addr{https, http} sta.BindAddr = []net.Addr{https, http}
log.Fatalf("BindAddr cannot be empty")
} }
// when cloak is started as a shadowsocks plugin, we parse the address ss-server // when cloak is started as a shadowsocks plugin
// is listening on into ProxyBook, and we parse the list of bindAddr
if pluginMode { if pluginMode {
ssLocalHost := os.Getenv("SS_LOCAL_HOST") ssLocalHost := os.Getenv("SS_LOCAL_HOST")
ssLocalPort := os.Getenv("SS_LOCAL_PORT") ssLocalPort := os.Getenv("SS_LOCAL_PORT")
raw.ProxyBook["shadowsocks"] = []string{"tcp", net.JoinHostPort(ssLocalHost, ssLocalPort)}
ssRemoteHost := os.Getenv("SS_REMOTE_HOST") sta.ProxyBook["shadowsocks"], err = net.ResolveTCPAddr("tcp", net.JoinHostPort(ssLocalHost, ssLocalPort))
ssRemotePort := os.Getenv("SS_REMOTE_PORT")
err = parseSSBindAddr(ssRemoteHost, ssRemotePort, &bindAddr)
if err != nil { if err != nil {
log.Fatalf("failed to parse SS_REMOTE_HOST and SS_REMOTE_PORT: %v", err) log.Fatal(err)
}
} }
sta, err := server.InitState(raw, common.RealWorldState) var ssBind string
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0
v4nv6 := len(strings.Split(ssRemoteHost, "|")) == 2
if v4nv6 {
ssBind = ":" + ssRemotePort
} else {
ssBind = net.JoinHostPort(ssRemoteHost, ssRemotePort)
}
ssBindAddr, err := net.ResolveTCPAddr("tcp", ssBind)
if err != nil { if err != nil {
log.Fatalf("unable to initialise server state: %v", err) log.Fatalf("unable to resolve bind address provided by SS: %v", err)
}
shouldAppend := true
for i, addr := range sta.BindAddr {
if addr.String() == ssBindAddr.String() {
shouldAppend = false
}
if addr.String() == ":"+ssRemotePort { // already listening on all interfaces
shouldAppend = false
}
if addr.String() == "0.0.0.0:"+ssRemotePort || addr.String() == "[::]:"+ssRemotePort {
// if config listens on one ip version but ss wants to listen on both,
// listen on both
if ssBindAddr.String() == ":"+ssRemotePort {
shouldAppend = true
sta.BindAddr[i] = ssBindAddr
}
}
}
if shouldAppend {
sta.BindAddr = append(sta.BindAddr, ssBindAddr)
}
} }
listen := func(bindAddr net.Addr) { listen := func(bindAddr net.Addr) {
@ -185,14 +313,20 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
server.Serve(listener, sta) for {
conn, err := listener.Accept()
if err != nil {
log.Errorf("%v", err)
continue
}
go dispatchConnection(conn, sta)
}
} }
for i, addr := range bindAddr { for i, addr := range sta.BindAddr {
if i != len(bindAddr)-1 { if i != len(sta.BindAddr)-1 {
go listen(addr) go listen(addr)
} else { } else {
// we block the main goroutine here so it doesn't quit
listen(addr) listen(addr)
} }
} }

View File

@ -1,136 +0,0 @@
package main
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseBindAddr(t *testing.T) {
t.Run("port only", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":443"})
assert.NoError(t, err)
assert.Equal(t, ":443", addrs[0].String())
})
t.Run("specific address", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
assert.NoError(t, err)
assert.Equal(t, "192.168.1.123:443", addrs[0].String())
})
t.Run("ipv6", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"[::]:443"})
assert.NoError(t, err)
assert.Equal(t, "[::]:443", addrs[0].String())
})
t.Run("mixed", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
assert.NoError(t, err)
assert.Equal(t, ":80", addrs[0].String())
assert.Equal(t, "[::]:443", addrs[1].String())
})
}
func assertSetEqual(t *testing.T, list1, list2 interface{}, msgAndArgs ...interface{}) (ok bool) {
return assert.Subset(t, list1, list2, msgAndArgs) && assert.Subset(t, list2, list1, msgAndArgs)
}
func TestParseSSBindAddr(t *testing.T) {
testTable := []struct {
name string
ssRemoteHost string
ssRemotePort string
ckBindAddr []net.Addr
expectedAddr []net.Addr
}{
{
"ss only ipv4",
"127.0.0.1",
"443",
[]net.Addr{},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 443,
},
},
},
{
"ss only ipv6",
"::",
"443",
[]net.Addr{},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
},
//{
// "ss only ipv4 and v6",
// "::|127.0.0.1",
// "443",
// []net.Addr{},
// []net.Addr{
// &net.TCPAddr{
// IP: net.ParseIP("::"),
// Port: 443,
// },
// &net.TCPAddr{
// IP: net.ParseIP("127.0.0.1"),
// Port: 443,
// },
// },
//},
{
"ss and existing agrees",
"::",
"443",
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
},
{
"ss adds onto existing",
"127.0.0.1",
"80",
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 80,
},
},
},
}
for _, test := range testTable {
test := test
t.Run(test.name, func(t *testing.T) {
assert.NoError(t, parseSSBindAddr(test.ssRemoteHost, test.ssRemotePort, &test.ckBindAddr))
assertSetEqual(t, test.ckBindAddr, test.expectedAddr)
})
}
}

View File

@ -2,21 +2,18 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
) )
func generateUID() string { func generateUID() string {
UID := make([]byte, 16) UID := make([]byte, 16)
common.CryptoRandRead(UID) rand.Read(UID)
return base64.StdEncoding.EncodeToString(UID) return b64(UID)
} }
func generateKeyPair() (string, string) { func generateKeyPair() (string, string) {
staticPv, staticPub, _ := ecdh.GenerateKey(rand.Reader) staticPv, staticPub, _ := ecdh.GenerateKey(rand.Reader)
marshPub := ecdh.Marshal(staticPub) marshPub := ecdh.Marshal(staticPub)
marshPv := staticPv.(*[32]byte)[:] marshPv := staticPv.(*[32]byte)[:]
return base64.StdEncoding.EncodeToString(marshPub), base64.StdEncoding.EncodeToString(marshPv) return b64(marshPub), b64(marshPv)
} }

View File

@ -1,4 +0,0 @@
coverage:
status:
project: off
patch: off

View File

@ -2,8 +2,8 @@
"Transport": "direct", "Transport": "direct",
"ProxyMethod": "shadowsocks", "ProxyMethod": "shadowsocks",
"EncryptionMethod": "plain", "EncryptionMethod": "plain",
"UID": "---Your UID here---", "UID": "5nneblJy6lniPJfr81LuYQ==",
"PublicKey": "---Public key here---", "PublicKey": "IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=",
"ServerName": "www.bing.com", "ServerName": "www.bing.com",
"NumConn": 4, "NumConn": 4,
"BrowserSig": "chrome", "BrowserSig": "chrome",

View File

@ -18,10 +18,11 @@
":80" ":80"
], ],
"BypassUID": [ "BypassUID": [
"---Bypass UID here---" "1rmq6Ag1jZJCImLBIL5wzQ=="
], ],
"RedirAddr": "cloudflare.com", "RedirAddr": "204.79.197.200",
"PrivateKey": "---Private key here---", "PrivateKey": "EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=",
"AdminUID": "---Admin UID here (optional)---", "AdminUID": "5nneblJy6lniPJfr81LuYQ==",
"DatabasePath": "userinfo.db" "DatabasePath": "userinfo.db",
"StreamTimeout": 300
} }

36
go.mod
View File

@ -1,30 +1,16 @@
module github.com/cbeuw/Cloak module github.com/cbeuw/Cloak
go 1.24.0 go 1.12
toolchain go1.24.2
require ( require (
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3 github.com/Yawning/chacha20 v0.0.0-20170904085104-e3b1f968fc63 // indirect
github.com/gorilla/mux v1.8.1 github.com/boltdb/bolt v1.3.1
github.com/gorilla/websocket v1.5.3 github.com/gorilla/mux v1.7.3
github.com/juju/ratelimit v1.0.2 github.com/gorilla/websocket v1.4.1
github.com/refraction-networking/utls v1.8.0 github.com/juju/ratelimit v1.0.1
github.com/sirupsen/logrus v1.9.3 github.com/kr/pretty v0.1.0 // indirect
github.com/stretchr/testify v1.10.0 github.com/refraction-networking/utls v0.0.0-20190824032329-cc2996c81813
go.etcd.io/bbolt v1.4.0 github.com/sirupsen/logrus v1.4.2
golang.org/x/crypto v0.37.0 golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4
) gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/sys v0.32.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

61
go.sum
View File

@ -1,61 +0,0 @@
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3 h1:LRxW8pdmWmyhoNh+TxUjxsAinGtCsVGjsl3xg6zoRSs=
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3/go.mod h1:6jR2SzckGv8hIIS9zWJ160mzGVVOYp4AXZMDtacL6LE=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI=
github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.6.6 h1:igFsYBUJPYM8Rno9xUuDoM5GQrVEqY4llzEXOkL43Ig=
github.com/refraction-networking/utls v1.6.6/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0=
github.com/refraction-networking/utls v1.7.0/go.mod h1:lV0Gwc1/Fi+HYH8hOtgFRdHfKo4FKSn6+FdyOz9hRms=
github.com/refraction-networking/utls v1.7.3 h1:L0WRhHY7Oq1T0zkdzVZMR6zWZv+sXbHB9zcuvsAEqCo=
github.com/refraction-networking/utls v1.7.3/go.mod h1:TUhh27RHMGtQvjQq+RyO11P6ZNQNBb3N0v7wsEjKAIQ=
github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE=
github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk=
go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,168 +1,85 @@
package client package client
import ( import (
"github.com/cbeuw/Cloak/internal/common" "encoding/binary"
utls "github.com/refraction-networking/utls" "github.com/cbeuw/Cloak/internal/util"
log "github.com/sirupsen/logrus"
"net" "net"
"strings"
log "github.com/sirupsen/logrus"
) )
const appDataMaxLength = 16401 type browser interface {
composeClientHello(chHiddenData) []byte
type clientHelloFields struct {
random []byte
sessionId []byte
x25519KeyShare []byte
serverName string
} }
type browser int func makeServerName(serverName string) []byte {
serverNameListLength := make([]byte, 2)
binary.BigEndian.PutUint16(serverNameListLength, uint16(len(serverName)+3))
serverNameType := []byte{0x00} // host_name
serverNameLength := make([]byte, 2)
binary.BigEndian.PutUint16(serverNameLength, uint16(len(serverName)))
ret := make([]byte, 2+1+2+len(serverName))
copy(ret[0:2], serverNameListLength)
copy(ret[2:3], serverNameType)
copy(ret[3:5], serverNameLength)
copy(ret[5:], serverName)
return ret
}
const ( // addExtensionRecord, add type, length to extension data
chrome = iota func addExtRec(typ []byte, data []byte) []byte {
firefox length := make([]byte, 2)
safari binary.BigEndian.PutUint16(length, uint16(len(data)))
) ret := make([]byte, 2+2+len(data))
copy(ret[0:2], typ)
copy(ret[2:4], length)
copy(ret[4:], data)
return ret
}
type DirectTLS struct { type DirectTLS struct {
*common.TLSConn Transport
browser browser
} }
var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"} func (DirectTLS) HasRecordLayer() bool { return true }
func (DirectTLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
func randomServerName() string { // PrepareConnection handles the TLS handshake for a given conn and returns the sessionKey
/*
Copyright: Proton AG
https://github.com/ProtonVPN/wireguard-go/commit/bcf344b39b213c1f32147851af0d2a8da9266883
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
charNum := int('z') - int('a') + 1
size := 3 + common.RandInt(10)
name := make([]byte, size)
for i := range name {
name[i] = byte(int('a') + common.RandInt(charNum))
}
return string(name) + "." + common.RandItem(topLevelDomains)
}
func buildClientHello(browser browser, fields clientHelloFields) ([]byte, error) {
// We don't use utls to handle connections (as it'll attempt a real TLS negotiation)
// We only want it to build the ClientHello locally
fakeConn := net.TCPConn{}
var helloID utls.ClientHelloID
switch browser {
case chrome:
helloID = utls.HelloChrome_Auto
case firefox:
helloID = utls.HelloFirefox_Auto
case safari:
helloID = utls.HelloSafari_Auto
}
uclient := utls.UClient(&fakeConn, &utls.Config{ServerName: fields.serverName}, helloID)
if err := uclient.BuildHandshakeState(); err != nil {
return []byte{}, err
}
if err := uclient.SetClientRandom(fields.random); err != nil {
return []byte{}, err
}
uclient.HandshakeState.Hello.SessionId = make([]byte, 32)
copy(uclient.HandshakeState.Hello.SessionId, fields.sessionId)
// Find the X25519 key share and overwrite it
var extIndex int
var keyShareIndex int
for i, ext := range uclient.Extensions {
ext, ok := ext.(*utls.KeyShareExtension)
if ok {
extIndex = i
for j, keyShare := range ext.KeyShares {
if keyShare.Group == utls.X25519 {
keyShareIndex = j
}
}
}
}
copy(uclient.Extensions[extIndex].(*utls.KeyShareExtension).KeyShares[keyShareIndex].Data, fields.x25519KeyShare)
if err := uclient.BuildHandshakeState(); err != nil {
return []byte{}, err
}
return uclient.HandshakeState.Hello.Raw, nil
}
// Handshake handles the TLS handshake for a given conn and returns the sessionKey
// if the server proceed with Cloak authentication // if the server proceed with Cloak authentication
func (tls *DirectTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) { func (DirectTLS) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) {
payload, sharedSecret := makeAuthenticationPayload(authInfo) preparedConn = conn
hd, sharedSecret := makeHiddenData(sta)
fields := clientHelloFields{ chOnly := sta.browser.composeClientHello(hd)
random: payload.randPubKey[:], chWithRecordLayer := util.AddRecordLayer(chOnly, []byte{0x16}, []byte{0x03, 0x01})
sessionId: payload.ciphertextWithTag[0:32], _, err = preparedConn.Write(chWithRecordLayer)
x25519KeyShare: payload.ciphertextWithTag[32:64],
serverName: authInfo.MockDomain,
}
if strings.EqualFold(fields.serverName, "random") {
fields.serverName = randomServerName()
}
var ch []byte
ch, err = buildClientHello(tls.browser, fields)
if err != nil {
return
}
chWithRecordLayer := common.AddRecordLayer(ch, common.Handshake, common.VersionTLS11)
_, err = rawConn.Write(chWithRecordLayer)
if err != nil { if err != nil {
return return
} }
log.Trace("client hello sent successfully") log.Trace("client hello sent successfully")
tls.TLSConn = common.NewTLSConn(rawConn)
buf := make([]byte, 1024) buf := make([]byte, 1024)
log.Trace("waiting for ServerHello") log.Trace("waiting for ServerHello")
_, err = tls.Read(buf) _, err = util.ReadTLS(preparedConn, buf)
if err != nil { if err != nil {
return return
} }
encrypted := append(buf[6:38], buf[84:116]...) encrypted := append(buf[11:43], buf[89:121]...)
nonce := encrypted[0:12] nonce := encrypted[0:12]
ciphertextWithTag := encrypted[12:60] ciphertextWithTag := encrypted[12:60]
sessionKeySlice, err := common.AESGCMDecrypt(nonce, sharedSecret[:], ciphertextWithTag) sessionKey, err = util.AESGCMDecrypt(nonce, sharedSecret, ciphertextWithTag)
if err != nil { if err != nil {
return return
} }
copy(sessionKey[:], sessionKeySlice)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
// ChangeCipherSpec and EncryptedCert (in the format of application data) // ChangeCipherSpec and EncryptedCert (in the format of application data)
_, err = tls.Read(buf) _, err = util.ReadTLS(preparedConn, buf)
if err != nil { if err != nil {
return return
} }
} }
return sessionKey, nil
return preparedConn, sessionKey, nil
} }

View File

@ -0,0 +1,43 @@
package client
import (
"bytes"
"encoding/hex"
"testing"
)
func htob(s string) []byte {
b, _ := hex.DecodeString(s)
return b
}
func TestMakeServerName(t *testing.T) {
type testingPair struct {
serverName string
target []byte
}
pairs := []testingPair{
{
"www.google.com",
htob("001100000e7777772e676f6f676c652e636f6d"),
},
{
"www.gstatic.com",
htob("001200000f7777772e677374617469632e636f6d"),
},
{
"googleads.g.doubleclick.net",
htob("001e00001b676f6f676c656164732e672e646f75626c65636c69636b2e6e6574"),
},
}
for _, p := range pairs {
if !bytes.Equal(makeServerName(p.serverName), p.target) {
t.Error(
"for", p.serverName,
"expecting", p.target,
"got", makeServerName(p.serverName))
}
}
}

View File

@ -1,25 +1,29 @@
package client package client
import ( import (
"crypto/rand"
"encoding/binary" "encoding/binary"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
log "github.com/sirupsen/logrus" "github.com/cbeuw/Cloak/internal/util"
"sync/atomic"
) )
const ( const (
UNORDERED_FLAG = 0x01 // 0000 0001 UNORDERED_FLAG = 0x01 // 0000 0001
) )
type authenticationPayload struct { type chHiddenData struct {
randPubKey [32]byte fullRaw []byte // pubkey, ciphertext, tag
ciphertextWithTag [64]byte chRandom []byte
chSessionId []byte
chX25519KeyShare []byte
chExtSNI []byte
} }
// makeAuthenticationPayload generates the ephemeral key pair, calculates the shared secret, and then compose and // makeHiddenData generates the ephemeral key pair, calculates the shared secret, and then compose and
// encrypt the authenticationPayload // encrypt the Authentication data. It also composes SNI extension.
func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sharedSecret [32]byte) { func makeHiddenData(sta *State) (ret chHiddenData, sharedSecret []byte) {
// random is marshalled ephemeral pub key 32 bytes
/* /*
Authentication data: Authentication data:
+----------+----------------+---------------------+-------------+--------------+--------+------------+ +----------+----------------+---------------------+-------------+--------------+--------+------------+
@ -28,29 +32,27 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes | | 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
+----------+----------------+---------------------+-------------+--------------+--------+------------+ +----------+----------------+---------------------+-------------+--------------+--------+------------+
*/ */
ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand) // The authentication ciphertext and its tag are then distributed among SessionId and X25519KeyShare
if err != nil { ephPv, ephPub, _ := ecdh.GenerateKey(rand.Reader)
log.Panicf("failed to generate ephemeral key pair: %v", err) ret.chRandom = ecdh.Marshal(ephPub)
}
copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
plaintext := make([]byte, 48) plaintext := make([]byte, 48)
copy(plaintext, authInfo.UID) copy(plaintext, sta.UID)
copy(plaintext[16:28], authInfo.ProxyMethod) copy(plaintext[16:28], sta.ProxyMethod)
plaintext[28] = authInfo.EncryptionMethod plaintext[28] = sta.EncryptionMethod
binary.BigEndian.PutUint64(plaintext[29:37], uint64(authInfo.WorldState.Now().UTC().Unix())) binary.BigEndian.PutUint64(plaintext[29:37], uint64(sta.Now().Unix()))
binary.BigEndian.PutUint32(plaintext[37:41], authInfo.SessionId) binary.BigEndian.PutUint32(plaintext[37:41], atomic.LoadUint32(&sta.SessionID))
if authInfo.Unordered { if sta.Unordered {
plaintext[41] |= UNORDERED_FLAG plaintext[41] |= UNORDERED_FLAG
} }
secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey) sharedSecret = ecdh.GenerateSharedSecret(ephPv, sta.staticPub)
if err != nil { nonce := ret.chRandom[0:12]
log.Panicf("error in generating shared secret: %v", err) ciphertextWithTag, _ := util.AESGCMEncrypt(nonce, sharedSecret, plaintext)
} ret.fullRaw = append(ret.chRandom, ciphertextWithTag...)
copy(sharedSecret[:], secret) ret.chSessionId = ciphertextWithTag[0:32]
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext) ret.chX25519KeyShare = ciphertextWithTag[32:64]
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:]) ret.chExtSNI = makeServerName(sta.ServerName)
return return
} }

View File

@ -1,73 +0,0 @@
package client
import (
"bytes"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/multiplex"
"github.com/stretchr/testify/assert"
)
func TestMakeAuthenticationPayload(t *testing.T) {
tests := []struct {
authInfo AuthInfo
expPayload authenticationPayload
expSecret [32]byte
}{
{
AuthInfo{
Unordered: false,
SessionId: 3421516597,
UID: []byte{
0x4c, 0xd8, 0xcc, 0x15, 0x60, 0x0d, 0x7e,
0xb6, 0x81, 0x31, 0xfd, 0x80, 0x97, 0x67, 0x37, 0x46},
ServerPubKey: &[32]byte{
0x21, 0x8a, 0x14, 0xce, 0x49, 0x5e, 0xfd, 0x3f,
0xe4, 0xae, 0x21, 0x3e, 0x51, 0xf7, 0x66, 0xec,
0x01, 0xd0, 0xb4, 0x87, 0x86, 0x9c, 0x15, 0x9b,
0x86, 0x19, 0x53, 0x6e, 0x60, 0xe9, 0x51, 0x42},
ProxyMethod: "shadowsocks",
EncryptionMethod: multiplex.EncryptionMethodPlain,
MockDomain: "d2jkinvisak5y9.cloudfront.net",
WorldState: common.WorldState{
Rand: bytes.NewBuffer([]byte{
0xf1, 0x1e, 0x42, 0xe1, 0x84, 0x22, 0x07, 0xc5,
0xc3, 0x5c, 0x0f, 0x7b, 0x01, 0xf3, 0x65, 0x2d,
0xd7, 0x9b, 0xad, 0xb0, 0xb2, 0x77, 0xa2, 0x06,
0x6b, 0x78, 0x1b, 0x74, 0x1f, 0x43, 0xc9, 0x80}),
Now: func() time.Time { return time.Unix(1579908372, 0) },
},
},
authenticationPayload{
randPubKey: [32]byte{
0xee, 0x9e, 0x41, 0x4e, 0xb3, 0x3b, 0x85, 0x03,
0x6d, 0x85, 0xba, 0x30, 0x11, 0x31, 0x10, 0x24,
0x4f, 0x7b, 0xd5, 0x38, 0x50, 0x0f, 0xf2, 0x4d,
0xa3, 0xdf, 0xba, 0x76, 0x0a, 0xe9, 0x19, 0x19},
ciphertextWithTag: [64]byte{
0x71, 0xb1, 0x6c, 0x5a, 0x60, 0x46, 0x90, 0x12,
0x36, 0x3b, 0x1b, 0xc4, 0x79, 0x3c, 0xab, 0xdd,
0x5a, 0x53, 0xc5, 0xed, 0xaf, 0xdb, 0x10, 0x98,
0x83, 0x96, 0x81, 0xa6, 0xfc, 0xa2, 0x1e, 0xb0,
0x89, 0xb2, 0x29, 0x71, 0x7e, 0x45, 0x97, 0x54,
0x11, 0x7d, 0x9b, 0x92, 0xbb, 0xd6, 0xce, 0x37,
0x3b, 0xb8, 0x8b, 0xfb, 0xb6, 0x40, 0xf0, 0x2c,
0x6c, 0x55, 0xb9, 0xfc, 0x5d, 0x34, 0x89, 0x41},
},
[32]byte{
0xc7, 0xc6, 0x9b, 0xbe, 0xec, 0xf8, 0x35, 0x55,
0x67, 0x20, 0xcd, 0xeb, 0x74, 0x16, 0xc5, 0x60,
0xee, 0x9d, 0x63, 0x1a, 0x44, 0xc5, 0x09, 0xf6,
0xe0, 0x24, 0xad, 0xd2, 0x10, 0xe3, 0x4a, 0x11},
},
}
for _, tc := range tests {
func() {
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
}()
}
}

103
internal/client/chrome.go Normal file
View File

@ -0,0 +1,103 @@
// Fingerprint of Chrome 76
package client
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
)
type Chrome struct{}
func makeGREASE() []byte {
// see https://tools.ietf.org/html/draft-davidben-tls-grease-01
// This is exclusive to Chrome.
var one [1]byte
rand.Read(one[:])
sixteenth := one[0] % 16
monoGREASE := byte(sixteenth*16 + 0xA)
doubleGREASE := []byte{monoGREASE, monoGREASE}
return doubleGREASE
}
func (c *Chrome) composeExtensions(sni []byte, keyShare []byte) []byte {
makeSupportedGroups := func() []byte {
suppGroupListLen := []byte{0x00, 0x08}
ret := make([]byte, 2+8)
copy(ret[0:2], suppGroupListLen)
copy(ret[2:4], makeGREASE())
copy(ret[4:], []byte{0x00, 0x1d, 0x00, 0x17, 0x00, 0x18})
return ret
}
makeKeyShare := func(hidden []byte) []byte {
ret := make([]byte, 43)
ret[0], ret[1] = 0x00, 0x29 // length 41
copy(ret[2:4], makeGREASE())
ret[4], ret[5] = 0x00, 0x01 // length 1
ret[6] = 0x00
ret[7], ret[8] = 0x00, 0x1d // group x25519
ret[9], ret[10] = 0x00, 0x20 // length 32
copy(ret[11:43], hidden)
return ret
}
// extension length is always 401, and server name length is variable
var ext [17][]byte
ext[0] = addExtRec(makeGREASE(), nil) // First GREASE
ext[1] = addExtRec([]byte{0x00, 0x00}, sni) // server name indication
ext[2] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
ext[3] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
ext[4] = addExtRec([]byte{0x00, 0x0a}, makeSupportedGroups()) // supported groups
ext[5] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
ext[6] = addExtRec([]byte{0x00, 0x23}, nil) // Session tickets
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
ext[7] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
ext[8] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
sigAlgo, _ := hex.DecodeString("0012040308040401050308050501080606010201")
ext[9] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
ext[10] = addExtRec([]byte{0x00, 0x12}, nil) // signed cert timestamp
ext[11] = addExtRec([]byte{0x00, 0x33}, makeKeyShare(keyShare)) // key share
ext[12] = addExtRec([]byte{0x00, 0x2d}, []byte{0x01, 0x01}) // psk key exchange modes
suppVersions, _ := hex.DecodeString("0a9A9A0304030303020301") // 9A9A needs to be a GREASE
copy(suppVersions[1:3], makeGREASE())
ext[13] = addExtRec([]byte{0x00, 0x2b}, suppVersions) // supported versions
ext[14] = addExtRec([]byte{0x00, 0x1b}, []byte{0x02, 0x00, 0x02})
ext[15] = addExtRec(makeGREASE(), []byte{0x00}) // Last GREASE
// len(ext[1]) + 172 + len(ext[16]) = 401
// len(ext[16]) = 229 - len(ext[1])
// 2+2+len(padding) = 229 - len(ext[1])
// len(padding) = 225 - len(ext[1])
ext[16] = addExtRec([]byte{0x00, 0x15}, make([]byte, 225-len(ext[1]))) // padding
var ret []byte
for _, e := range ext {
ret = append(ret, e...)
}
return ret
}
func (c *Chrome) composeClientHello(hd chHiddenData) (ch []byte) {
var clientHello [12][]byte
clientHello[0] = []byte{0x01} // handshake type
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
clientHello[2] = []byte{0x03, 0x03} // client version
clientHello[3] = hd.chRandom // random
clientHello[4] = []byte{0x20} // session id length 32
clientHello[5] = hd.chSessionId // session id
clientHello[6] = []byte{0x00, 0x22} // cipher suites length 34
cipherSuites, _ := hex.DecodeString("130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a")
clientHello[7] = append(makeGREASE(), cipherSuites...) // cipher suites
clientHello[8] = []byte{0x01} // compression methods length 1
clientHello[9] = []byte{0x00} // compression methods
clientHello[11] = c.composeExtensions(hd.chExtSNI, hd.chX25519KeyShare)
clientHello[10] = []byte{0x00, 0x00} // extensions length 401
binary.BigEndian.PutUint16(clientHello[10], uint16(len(clientHello[11])))
var ret []byte
for _, c := range clientHello {
ret = append(ret, c...)
}
return ret
}

View File

@ -0,0 +1,48 @@
package client
import (
"encoding/hex"
"testing"
)
func TestMakeGREASE(t *testing.T) {
a := hex.EncodeToString(makeGREASE())
if a[1] != 'a' || a[3] != 'a' {
t.Errorf("GREASE got %v", a)
}
var GREASEs []string
for i := 0; i < 50; i++ {
GREASEs = append(GREASEs, hex.EncodeToString(makeGREASE()))
}
var eqCount int
for _, g := range GREASEs {
if a == g {
eqCount++
}
}
if eqCount > 40 {
t.Error("GREASE is not random", GREASEs)
}
}
func TestComposeExtension(t *testing.T) {
serverName := "cdn.bizible.com"
keyShare, _ := hex.DecodeString("010a8896b68fb16e2a245ed87be2699348ab72068bb326eac5beaa00fa56ff17")
sni := makeServerName(serverName)
result := (&Chrome{}).composeExtensions(sni, keyShare)
target, _ := hex.DecodeString("5a5a000000000014001200000f63646e2e62697a69626c652e636f6d00170000ff01000100000a000a0008fafa001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029fafa000100001d0020010a8896b68fb16e2a245ed87be2699348ab72068bb326eac5beaa00fa56ff17002d00020101002b000b0aaaaa0304030303020301001b0003020002eaea000100001500c9000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
for p := 0; p < len(result); {
// skip GREASEs
if p == 0 || p == 43 || p == 122 || p == 174 || p == 191 {
p += 2
continue
}
if result[p] != target[p] {
t.Errorf("inequality at %v", p)
}
p += 1
}
}

View File

@ -1,83 +0,0 @@
package client
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
// On different invocations to MakeSession, authInfo.SessionId MUST be different
func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *mux.Session {
log.Info("Attempting to start a new session")
connsCh := make(chan net.Conn, connConfig.NumConn)
var _sessionKey atomic.Value
var wg sync.WaitGroup
for i := 0; i < connConfig.NumConn; i++ {
wg.Add(1)
transportConfig := connConfig.Transport
go func() {
makeconn:
transportConn := transportConfig.CreateTransport()
remoteConn, err := dialer.Dial("tcp", connConfig.RemoteAddr)
if err != nil {
log.Errorf("Failed to establish new connections to remote: %v", err)
// TODO increase the interval if failed multiple times
time.Sleep(time.Second * 3)
goto makeconn
}
sk, err := transportConn.Handshake(remoteConn, authInfo)
if err != nil {
log.Errorf("Failed to prepare connection to remote: %v", err)
transportConn.Close()
// In Cloak v2.11.0, we've updated uTLS version and subsequently increased the first packet size for chrome above 1500
// https://github.com/cbeuw/Cloak/pull/306#issuecomment-2862728738. As a backwards compatibility feature, if we fail
// to connect using chrome signature, retry with firefox which has a smaller packet size.
if transportConfig.mode == "direct" && transportConfig.browser == chrome {
transportConfig.browser = firefox
log.Warnf("failed to connect with chrome signature, falling back to retry with firefox")
}
time.Sleep(time.Second * 3)
goto makeconn
}
// sessionKey given by each connection should be identical
_sessionKey.Store(sk)
connsCh <- transportConn
wg.Done()
}()
}
wg.Wait()
log.Debug("All underlying connections established")
sessionKey := _sessionKey.Load().([32]byte)
obfuscator, err := mux.MakeObfuscator(authInfo.EncryptionMethod, sessionKey)
if err != nil {
log.Fatal(err)
}
seshConfig := mux.SessionConfig{
Singleplex: connConfig.Singleplex,
Obfuscator: obfuscator,
Valve: nil,
Unordered: authInfo.Unordered,
MsgOnWireSizeLimit: appDataMaxLength,
}
sesh := mux.MakeSession(authInfo.SessionId, seshConfig)
for i := 0; i < connConfig.NumConn; i++ {
conn := <-connsCh
sesh.AddConnection(conn)
}
log.Infof("Session %v established", authInfo.SessionId)
return sesh
}

View File

@ -0,0 +1,77 @@
// Fingerprint of Firefox 68
package client
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
)
type Firefox struct{}
func (f *Firefox) composeExtensions(SNI []byte, keyShare []byte) []byte {
composeKeyShare := func(hidden []byte) []byte {
ret := make([]byte, 107)
ret[0], ret[1] = 0x00, 0x69 // length 105
ret[2], ret[3] = 0x00, 0x1d // group x25519
ret[4], ret[5] = 0x00, 0x20 // length 32
copy(ret[6:38], hidden)
ret[38], ret[39] = 0x00, 0x17 // group secp256r1
ret[40], ret[41] = 0x00, 0x41 // length 65
rand.Read(ret[42:107])
return ret
}
// extension length is always 399, and server name length is variable
var ext [14][]byte
ext[0] = addExtRec([]byte{0x00, 0x00}, SNI) // server name indication
ext[1] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
ext[2] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
suppGroup, _ := hex.DecodeString("000c001d00170018001901000101")
ext[3] = addExtRec([]byte{0x00, 0x0a}, suppGroup) // supported groups
ext[4] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
ext[5] = addExtRec([]byte{0x00, 0x23}, []byte{}) // Session tickets
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
ext[6] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
ext[7] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
ext[8] = addExtRec([]byte{0x00, 0x33}, composeKeyShare(keyShare)) // key share
suppVersions, _ := hex.DecodeString("080304030303020301")
ext[9] = addExtRec([]byte{0x00, 0x2b}, suppVersions) // supported versions
sigAlgo, _ := hex.DecodeString("001604030503060308040805080604010501060102030201")
ext[10] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
ext[11] = addExtRec([]byte{0x00, 0x2d}, []byte{0x01, 0x01}) // psk key exchange modes
ext[12] = addExtRec([]byte{0x00, 0x1c}, []byte{0x40, 0x01}) // record size limit
// len(ext[0]) + 237 + 4 + len(padding) = 399
// len(padding) = 158 - len(ext[0])
ext[13] = addExtRec([]byte{0x00, 0x15}, make([]byte, 163-len(SNI))) // padding
var ret []byte
for _, e := range ext {
ret = append(ret, e...)
}
return ret
}
func (f *Firefox) composeClientHello(hd chHiddenData) (ch []byte) {
var clientHello [12][]byte
clientHello[0] = []byte{0x01} // handshake type
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
clientHello[2] = []byte{0x03, 0x03} // client version
clientHello[3] = hd.chRandom // random
clientHello[4] = []byte{0x20} // session id length 32
clientHello[5] = hd.chSessionId // session id
clientHello[6] = []byte{0x00, 0x24} // cipher suites length 36
cipherSuites, _ := hex.DecodeString("130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a")
clientHello[7] = cipherSuites // cipher suites
clientHello[8] = []byte{0x01} // compression methods length 1
clientHello[9] = []byte{0x00} // compression methods
clientHello[11] = f.composeExtensions(hd.chExtSNI, hd.chX25519KeyShare)
clientHello[10] = []byte{0x00, 0x00} // extensions length
binary.BigEndian.PutUint16(clientHello[10], uint16(len(clientHello[11])))
var ret []byte
for _, c := range clientHello {
ret = append(ret, c...)
}
return ret
}

View File

@ -0,0 +1,20 @@
package client
import (
"bytes"
"encoding/hex"
"testing"
)
func TestComposeExtensions(t *testing.T) {
target, _ := hex.DecodeString("000000170015000012636f6e73656e742e676f6f676c652e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00206075db0a43812b2e4e0f44157f04295b484ccfc6d70e577c1e6113aa18e088270017004104948052ae52043e654641660ebbadb527c8280262e61f64b0f6f1794f32e1000865a49e4cbe2027c78e7180861e4336300815fa0f1b0091c4d788b97f809a47d3002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c000240010015008c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
serverName := "consent.google.com"
keyShare, _ := hex.DecodeString("6075db0a43812b2e4e0f44157f04295b484ccfc6d70e577c1e6113aa18e08827")
sni := makeServerName(serverName)
result := (&Firefox{}).composeExtensions(sni, keyShare)
// skip random secp256r1
if !bytes.Equal(result[:137], target[:137]) || !bytes.Equal(result[202:], target[202:]) {
t.Errorf("got %x", result)
}
}

View File

@ -1,153 +0,0 @@
package client
import (
"io"
"net"
"sync"
"time"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
localConn, err := bindFunc()
if err != nil {
log.Fatal(err)
}
streams := make(map[string]*mux.Stream)
var streamsMutex sync.Mutex
data := make([]byte, 8192)
for {
i, addr, err := localConn.ReadFrom(data)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
continue
}
if !singleplex && (sesh == nil || sesh.IsClosed()) {
sesh = newSeshFunc()
}
streamsMutex.Lock()
stream, ok := streams[addr.String()]
if !ok {
if singleplex {
sesh = newSeshFunc()
}
stream, err = sesh.OpenStream()
if err != nil {
if singleplex {
sesh.Close()
}
log.Errorf("Failed to open stream: %v", err)
streamsMutex.Unlock()
continue
}
streams[addr.String()] = stream
streamsMutex.Unlock()
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
proxyAddr := addr
go func(stream *mux.Stream, localConn *net.UDPConn) {
buf := make([]byte, 8192)
for {
n, err := stream.Read(buf)
if err != nil {
log.Tracef("copying stream to proxy client: %v", err)
break
}
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
_, err = localConn.WriteTo(buf[:n], proxyAddr)
if err != nil {
log.Tracef("copying stream to proxy client: %v", err)
break
}
}
streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close()
return
}(stream, localConn)
} else {
streamsMutex.Unlock()
}
_, err = stream.Write(data[:i])
if err != nil {
log.Tracef("copying proxy client to stream: %v", err)
streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close()
continue
}
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
}
}
func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
for {
localConn, err := listener.Accept()
if err != nil {
log.Fatal(err)
continue
}
if !singleplex && (sesh == nil || sesh.IsClosed()) {
sesh = newSeshFunc()
}
go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) {
if singleplex {
sesh = newSeshFunc()
}
data := make([]byte, 10240)
_ = localConn.SetReadDeadline(time.Now().Add(streamTimeout))
i, err := io.ReadAtLeast(localConn, data, 1)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
var zeroTime time.Time
_ = localConn.SetReadDeadline(zeroTime)
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
if singleplex {
sesh.Close()
}
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
stream.Close()
return
}
go func() {
if _, err := common.Copy(localConn, stream); err != nil {
log.Tracef("copying stream to proxy client: %v", err)
}
}()
if _, err = common.Copy(stream, localConn); err != nil {
log.Tracef("copying proxy client to stream: %v", err)
}
}(sesh, localConn, streamTimeout)
}
}

View File

@ -2,88 +2,62 @@ package client
import ( import (
"crypto" "crypto"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "errors"
"io/ioutil" "io/ioutil"
"net"
"strings" "strings"
"time" "time"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
) )
// RawConfig represents the fields in the config json file // rawConfig represents the fields in the config json file
// nullable means if it's empty, a default value will be chosen in ProcessRawConfig type rawConfig struct {
// jsonOptional means if the json's empty, its value will be set from environment variables or commandline args
// but it mustn't be empty when ProcessRawConfig is called
type RawConfig struct {
ServerName string ServerName string
ProxyMethod string ProxyMethod string
EncryptionMethod string EncryptionMethod string
UID []byte UID string
PublicKey []byte PublicKey string
BrowserSig string
Transport string
NumConn int NumConn int
LocalHost string // jsonOptional StreamTimeout int
LocalPort string // jsonOptional
RemoteHost string // jsonOptional
RemotePort string // jsonOptional
AlternativeNames []string // jsonOptional
// defaults set in ProcessRawConfig
UDP bool // nullable
BrowserSig string // nullable
Transport string // nullable
CDNOriginHost string // nullable
CDNWsUrlPath string // nullable
StreamTimeout int // nullable
KeepAlive int // nullable
} }
type RemoteConnConfig struct { // State stores the parsed configuration fields
Singleplex bool type State struct {
NumConn int LocalHost string
KeepAlive time.Duration LocalPort string
RemoteAddr string RemoteHost string
Transport TransportConfig RemotePort string
} Unordered bool
type LocalConnConfig struct { Transport Transport
LocalAddr string
Timeout time.Duration
MockDomainList []string
}
type AuthInfo struct { SessionID uint32
UID []byte UID []byte
SessionId uint32
staticPub crypto.PublicKey
Now func() time.Time // for easier testing
browser browser
ProxyMethod string ProxyMethod string
EncryptionMethod byte EncryptionMethod byte
Unordered bool ServerName string
ServerPubKey crypto.PublicKey NumConn int
MockDomain string Timeout time.Duration
WorldState common.WorldState
} }
// semi-colon separated value. This is for Android plugin options // semi-colon separated value. This is for Android plugin options
func ssvToJson(ssv string) (ret []byte) { func ssvToJson(ssv string) (ret []byte) {
elem := func(val string, lst []string) bool {
for _, v := range lst {
if val == v {
return true
}
}
return false
}
unescape := func(s string) string { unescape := func(s string) string {
r := strings.Replace(s, `\\`, `\`, -1) r := strings.Replace(s, `\\`, `\`, -1)
r = strings.Replace(r, `\=`, `=`, -1) r = strings.Replace(r, `\=`, `=`, -1)
r = strings.Replace(r, `\;`, `;`, -1) r = strings.Replace(r, `\;`, `;`, -1)
return r return r
} }
unquoted := []string{"NumConn", "StreamTimeout", "KeepAlive", "UDP"}
lines := strings.Split(unescape(ssv), ";") lines := strings.Split(unescape(ssv), ";")
ret = []byte("{") ret = []byte("{")
for _, ln := range lines { for _, ln := range lines {
@ -91,29 +65,11 @@ func ssvToJson(ssv string) (ret []byte) {
break break
} }
sp := strings.SplitN(ln, "=", 2) sp := strings.SplitN(ln, "=", 2)
if len(sp) < 2 {
log.Errorf("Malformed config option: %v", ln)
continue
}
key := sp[0] key := sp[0]
value := sp[1] value := sp[1]
if strings.HasPrefix(key, "AlternativeNames") {
switch strings.Contains(value, ",") {
case true:
domains := strings.Split(value, ",")
for index, domain := range domains {
domains[index] = `"` + domain + `"`
}
value = strings.Join(domains, ",")
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
case false:
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
}
continue
}
// JSON doesn't like quotation marks around int and bool // JSON doesn't like quotation marks around int and bool
// This is extremely ugly but it's still better than writing a tokeniser // This is extremely ugly but it's still better than writing a tokeniser
if elem(key, unquoted) { if key == "NumConn" || key == "Unordered" || key == "StreamTimeout" {
ret = append(ret, []byte(`"`+key+`":`+value+`,`)...) ret = append(ret, []byte(`"`+key+`":`+value+`,`)...)
} else { } else {
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...) ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
@ -124,7 +80,8 @@ func ssvToJson(ssv string) (ret []byte) {
return ret return ret
} }
func ParseConfig(conf string) (raw *RawConfig, err error) { // ParseConfig parses the config (either a path to json or Android config) into a State variable
func (sta *State) ParseConfig(conf string) (err error) {
var content []byte var content []byte
// Checking if it's a path to json or a ssv string // Checking if it's a path to json or a ssv string
if strings.Contains(conf, ";") && strings.Contains(conf, "=") { if strings.Contains(conf, ";") && strings.Contains(conf, "=") {
@ -132,148 +89,63 @@ func ParseConfig(conf string) (raw *RawConfig, err error) {
} else { } else {
content, err = ioutil.ReadFile(conf) content, err = ioutil.ReadFile(conf)
if err != nil { if err != nil {
return return err
} }
} }
var preParse rawConfig
raw = new(RawConfig) err = json.Unmarshal(content, &preParse)
err = json.Unmarshal(content, &raw)
if err != nil { if err != nil {
return return err
}
return
}
func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
nullErr := func(field string) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
err = fmt.Errorf("%v cannot be empty", field)
return
} }
auth.UID = raw.UID switch strings.ToLower(preParse.EncryptionMethod) {
auth.Unordered = raw.UDP
if raw.ServerName == "" {
return nullErr("ServerName")
}
auth.MockDomain = raw.ServerName
var filteredAlternativeNames []string
for _, alternativeName := range raw.AlternativeNames {
if len(alternativeName) > 0 {
filteredAlternativeNames = append(filteredAlternativeNames, alternativeName)
}
}
raw.AlternativeNames = filteredAlternativeNames
local.MockDomainList = raw.AlternativeNames
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
if raw.ProxyMethod == "" {
return nullErr("ServerName")
}
auth.ProxyMethod = raw.ProxyMethod
if len(raw.UID) == 0 {
return nullErr("UID")
}
// static public key
if len(raw.PublicKey) == 0 {
return nullErr("PublicKey")
}
pub, ok := ecdh.Unmarshal(raw.PublicKey)
if !ok {
err = fmt.Errorf("failed to unmarshal Public key")
return
}
auth.ServerPubKey = pub
auth.WorldState = worldState
// Encryption method
switch strings.ToLower(raw.EncryptionMethod) {
case "plain": case "plain":
auth.EncryptionMethod = mux.EncryptionMethodPlain sta.EncryptionMethod = mux.E_METHOD_PLAIN
case "aes-gcm", "aes-256-gcm": case "aes-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES256GCM sta.EncryptionMethod = mux.E_METHOD_AES_GCM
case "aes-128-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES128GCM
case "chacha20-poly1305": case "chacha20-poly1305":
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305 sta.EncryptionMethod = mux.E_METHOD_CHACHA20_POLY1305
default: default:
err = fmt.Errorf("unknown encryption method %v", raw.EncryptionMethod) return errors.New("Unknown encryption method")
return
} }
if raw.RemoteHost == "" { switch strings.ToLower(preParse.BrowserSig) {
return nullErr("RemoteHost")
}
if raw.RemotePort == "" {
return nullErr("RemotePort")
}
remote.RemoteAddr = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
if raw.NumConn <= 0 {
remote.NumConn = 1
remote.Singleplex = true
} else {
remote.NumConn = raw.NumConn
remote.Singleplex = false
}
// Transport and (if TLS mode), browser
switch strings.ToLower(raw.Transport) {
case "cdn":
var cdnDomainPort string
if raw.CDNOriginHost == "" {
cdnDomainPort = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
} else {
cdnDomainPort = net.JoinHostPort(raw.CDNOriginHost, raw.RemotePort)
}
if raw.CDNWsUrlPath == "" {
raw.CDNWsUrlPath = "/"
}
remote.Transport = TransportConfig{
mode: "cdn",
wsUrl: "ws://" + cdnDomainPort + raw.CDNWsUrlPath,
}
case "direct":
fallthrough
default:
var browser browser
switch strings.ToLower(raw.BrowserSig) {
case "firefox":
browser = firefox
case "safari":
browser = safari
case "chrome": case "chrome":
fallthrough sta.browser = &Chrome{}
case "firefox":
sta.browser = &Firefox{}
default: default:
browser = chrome return errors.New("unsupported browser signature")
}
remote.Transport = TransportConfig{
mode: "direct",
browser: browser,
}
} }
// KeepAlive switch strings.ToLower(preParse.Transport) {
if raw.KeepAlive <= 0 { case "direct":
remote.KeepAlive = -1 sta.Transport = DirectTLS{}
} else { case "cdn":
remote.KeepAlive = remote.KeepAlive * time.Second sta.Transport = WSOverTLS{}
default:
sta.Transport = DirectTLS{}
} }
if raw.LocalHost == "" { sta.ProxyMethod = preParse.ProxyMethod
return nullErr("LocalHost") sta.ServerName = preParse.ServerName
} sta.NumConn = preParse.NumConn
if raw.LocalPort == "" { sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
return nullErr("LocalPort")
}
local.LocalAddr = net.JoinHostPort(raw.LocalHost, raw.LocalPort)
// stream no write timeout
if raw.StreamTimeout == 0 {
local.Timeout = 300 * time.Second
} else {
local.Timeout = time.Duration(raw.StreamTimeout) * time.Second
}
return uid, err := base64.StdEncoding.DecodeString(preParse.UID)
if err != nil {
return errors.New("Failed to parse UID: " + err.Error())
}
sta.UID = uid
pubBytes, err := base64.StdEncoding.DecodeString(preParse.PublicKey)
if err != nil {
return errors.New("Failed to parse Public key: " + err.Error())
}
pub, ok := ecdh.Unmarshal(pubBytes)
if !ok {
return errors.New("Failed to unmarshal Public key")
}
sta.staticPub = pub
return nil
} }

View File

@ -1,37 +1,20 @@
package client package client
import ( import (
"io/ioutil" "bytes"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestParseConfig(t *testing.T) { func TestSSVtoJson(t *testing.T) {
ssv := "UID=iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=;PublicKey=IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=;" + ssv := "UID=iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=;PublicKey=IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=;ServerName=www.bing.com;NumConn=4;MaskBrowser=chrome;"
"ServerName=www.bing.com;NumConn=4;MaskBrowser=chrome;ProxyMethod=shadowsocks;EncryptionMethod=plain"
json := ssvToJson(ssv) json := ssvToJson(ssv)
expected := []byte(`{"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=","PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=","ServerName":"www.bing.com","NumConn":4,"MaskBrowser":"chrome","ProxyMethod":"shadowsocks","EncryptionMethod":"plain"}`) expected := []byte(`{"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=","PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=","ServerName":"www.bing.com","NumConn":4,"MaskBrowser":"chrome"}`)
if !bytes.Equal(expected, json) {
t.Run("byte equality", func(t *testing.T) { t.Error(
assert.Equal(t, expected, json) "For", "ssvToJson",
}) "expecting", string(expected),
"got", string(json),
t.Run("struct equality", func(t *testing.T) { )
tmpConfig, _ := ioutil.TempFile("", "ck_client_config") }
_, _ = tmpConfig.Write(expected)
parsedFromSSV, err := ParseConfig(ssv)
assert.NoError(t, err)
parsedFromJson, err := ParseConfig(tmpConfig.Name())
assert.NoError(t, err)
assert.Equal(t, parsedFromJson, parsedFromSSV)
})
t.Run("empty file", func(t *testing.T) {
tmpConfig, _ := ioutil.TempFile("", "ck_client_config")
_, err := ParseConfig(tmpConfig.Name())
assert.Error(t, err)
})
} }

View File

@ -1,33 +1,9 @@
package client package client
import ( import "net"
"net"
)
type Transport interface { type Transport interface {
Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) PrepareConnection(*State, net.Conn) (net.Conn, []byte, error)
net.Conn HasRecordLayer() bool
} UnitReadFunc() func(net.Conn, []byte) (int, error)
type TransportConfig struct {
mode string
wsUrl string
browser browser
}
func (t TransportConfig) CreateTransport() Transport {
switch t.mode {
case "cdn":
return &WSOverTLS{
wsUrl: t.wsUrl,
}
case "direct":
return &DirectTLS{
browser: t.browser,
}
default:
return nil
}
} }

View File

@ -4,81 +4,61 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"github.com/cbeuw/Cloak/internal/util"
"github.com/gorilla/websocket"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"github.com/cbeuw/Cloak/internal/common"
"github.com/gorilla/websocket"
utls "github.com/refraction-networking/utls" utls "github.com/refraction-networking/utls"
) )
type WSOverTLS struct { type WSOverTLS struct {
*common.WebSocketConn Transport
wsUrl string
} }
func (ws *WSOverTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) { func (WSOverTLS) HasRecordLayer() bool { return false }
func (WSOverTLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
func (WSOverTLS) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) {
utlsConfig := &utls.Config{ utlsConfig := &utls.Config{
ServerName: authInfo.MockDomain, ServerName: sta.ServerName,
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
uconn := utls.UClient(rawConn, utlsConfig, utls.HelloChrome_Auto) uconn := utls.UClient(conn, utlsConfig, utls.HelloChrome_Auto)
err = uconn.BuildHandshakeState()
if err != nil {
return
}
for i, extension := range uconn.Extensions {
_, ok := extension.(*utls.ALPNExtension)
if ok {
uconn.Extensions = append(uconn.Extensions[:i], uconn.Extensions[i+1:]...)
break
}
}
err = uconn.Handshake() err = uconn.Handshake()
preparedConn = uconn
if err != nil { if err != nil {
return return
} }
u, err := url.Parse(ws.wsUrl) u, err := url.Parse("ws://" + sta.RemoteHost + ":" + sta.RemotePort) //TODO IPv6
if err != nil { if err != nil {
return sessionKey, fmt.Errorf("failed to parse ws url: %v", err) return preparedConn, nil, fmt.Errorf("failed to parse ws url: %v", err)
} }
payload, sharedSecret := makeAuthenticationPayload(authInfo) hd, sharedSecret := makeHiddenData(sta)
header := http.Header{} header := http.Header{}
header.Add("hidden", base64.StdEncoding.EncodeToString(append(payload.randPubKey[:], payload.ciphertextWithTag[:]...))) header.Add("hidden", base64.StdEncoding.EncodeToString(hd.fullRaw))
c, _, err := websocket.NewClient(uconn, u, header, 16480, 16480) c, _, err := websocket.NewClient(preparedConn, u, header, 16480, 16480)
if err != nil { if err != nil {
return sessionKey, fmt.Errorf("failed to handshake: %v", err) return preparedConn, nil, fmt.Errorf("failed to handshake: %v", err)
} }
ws.WebSocketConn = &common.WebSocketConn{Conn: c} preparedConn = &util.WebSocketConn{Conn: c}
buf := make([]byte, 128) buf := make([]byte, 128)
n, err := ws.Read(buf) n, err := preparedConn.Read(buf)
if err != nil { if err != nil {
return sessionKey, fmt.Errorf("failed to read reply: %v", err) return preparedConn, nil, fmt.Errorf("failed to read reply: %v", err)
} }
if n != 60 { if n != 60 {
return sessionKey, errors.New("reply must be 60 bytes") return preparedConn, nil, errors.New("reply must be 60 bytes")
} }
reply := buf[:60] reply := buf[:60]
sessionKeySlice, err := common.AESGCMDecrypt(reply[:12], sharedSecret[:], reply[12:]) sessionKey, err = util.AESGCMDecrypt(reply[:12], sharedSecret, reply[12:])
if err != nil {
return
}
copy(sessionKey[:], sessionKeySlice)
return return
} }
func (ws *WSOverTLS) Close() error {
if ws.WebSocketConn != nil {
return ws.WebSocketConn.Close()
}
return nil
}

View File

@ -1,79 +0,0 @@
/*
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
Forked from https://golang.org/src/io/io.go
*/
package common
import (
"io"
"net"
)
func Copy(dst net.Conn, src net.Conn) (written int64, err error) {
defer func() { src.Close(); dst.Close() }()
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(dst)
}
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
if rt, ok := dst.(io.ReaderFrom); ok {
return rt.ReadFrom(src)
}
size := 32 * 1024
buf := make([]byte, size)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}

View File

@ -1,97 +0,0 @@
package common
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"math/big"
"time"
log "github.com/sirupsen/logrus"
)
func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != aesgcm.NonceSize() {
// check here so it doesn't panic
return nil, errors.New("incorrect nonce size")
}
return aesgcm.Seal(nil, nonce, plaintext, nil), nil
}
func AESGCMDecrypt(nonce []byte, key []byte, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != aesgcm.NonceSize() {
// check here so it doesn't panic
return nil, errors.New("incorrect nonce size")
}
plain, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plain, nil
}
func CryptoRandRead(buf []byte) {
RandRead(rand.Reader, buf)
}
func backoff(f func() error) {
err := f()
if err == nil {
return
}
waitDur := [10]time.Duration{5 * time.Millisecond, 10 * time.Millisecond, 30 * time.Millisecond, 50 * time.Millisecond,
100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
3 * time.Second, 5 * time.Second}
for i := 0; i < 10; i++ {
log.Errorf("Failed to get random: %v. Retrying...", err)
err = f()
if err == nil {
return
}
time.Sleep(waitDur[i])
}
log.Fatal("Cannot get random after 10 retries")
}
func RandRead(randSource io.Reader, buf []byte) {
backoff(func() error {
_, err := randSource.Read(buf)
return err
})
}
func RandItem[T any](list []T) T {
return list[RandInt(len(list))]
}
func RandInt(n int) int {
s := new(int)
backoff(func() error {
size, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
if err != nil {
return err
}
*s = int(size.Int64())
return nil
})
return *s
}

View File

@ -1,95 +0,0 @@
package common
import (
"bytes"
"encoding/hex"
"errors"
"io"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
const gcmTagSize = 16
func TestAESGCM(t *testing.T) {
// test vectors from https://luca-giuzzi.unibs.it/corsi/Support/papers-cryptography/gcm-spec.pdf
t.Run("correct 128", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("000000000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
encryptedWithTag, err := AESGCMEncrypt(nonce, key, plaintext)
assert.NoError(t, err)
assert.Equal(t, ciphertext, encryptedWithTag[:len(plaintext)])
assert.Equal(t, tag, encryptedWithTag[len(plaintext):len(plaintext)+gcmTagSize])
decrypted, err := AESGCMDecrypt(nonce, key, encryptedWithTag)
assert.NoError(t, err)
// slight inconvenience here that assert.Equal does not consider a nil slice and an empty slice to be
// equal. decrypted should be []byte(nil) but plaintext is []byte{}
assert.True(t, bytes.Equal(plaintext, decrypted))
})
t.Run("bad key size", func(t *testing.T) {
key, _ := hex.DecodeString("0000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("000000000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
_, err := AESGCMEncrypt(nonce, key, plaintext)
assert.Error(t, err)
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
t.Run("bad nonce size", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("00000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
_, err := AESGCMEncrypt(nonce, key, plaintext)
assert.Error(t, err)
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
t.Run("bad tag", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
nonce, _ := hex.DecodeString("00000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("fffffccefa7e3061367f1d57a4e745ff")
_, err := AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
}
type failingReader struct {
fails int
reader io.Reader
}
func (f *failingReader) Read(p []byte) (n int, err error) {
if f.fails > 0 {
f.fails -= 1
return 0, errors.New("no data for you yet")
} else {
return f.reader.Read(p)
}
}
func TestRandRead(t *testing.T) {
failer := &failingReader{
fails: 3,
reader: rand.New(rand.NewSource(0)),
}
readBuf := make([]byte, 10)
RandRead(failer, readBuf)
assert.NotEqual(t, [10]byte{}, readBuf)
}

View File

@ -1,7 +0,0 @@
package common
import "net"
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

View File

@ -1,112 +0,0 @@
package common
import (
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
const (
VersionTLS11 = 0x0301
VersionTLS13 = 0x0303
recordLayerLength = 5
Handshake = 22
ApplicationData = 23
initialWriteBufSize = 14336
)
func AddRecordLayer(input []byte, typ byte, ver uint16) []byte {
msgLen := len(input)
retLen := msgLen + recordLayerLength
var ret []byte
ret = make([]byte, retLen)
copy(ret[recordLayerLength:], input)
ret[0] = typ
ret[1] = byte(ver >> 8)
ret[2] = byte(ver)
ret[3] = byte(msgLen >> 8)
ret[4] = byte(msgLen)
return ret
}
type TLSConn struct {
net.Conn
writeBufPool sync.Pool
}
func NewTLSConn(conn net.Conn) *TLSConn {
return &TLSConn{
Conn: conn,
writeBufPool: sync.Pool{New: func() interface{} {
b := make([]byte, 0, initialWriteBufSize)
b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF))
return &b
}},
}
}
func (tls *TLSConn) LocalAddr() net.Addr {
return tls.Conn.LocalAddr()
}
func (tls *TLSConn) RemoteAddr() net.Addr {
return tls.Conn.RemoteAddr()
}
func (tls *TLSConn) SetDeadline(t time.Time) error {
return tls.Conn.SetDeadline(t)
}
func (tls *TLSConn) SetReadDeadline(t time.Time) error {
return tls.Conn.SetReadDeadline(t)
}
func (tls *TLSConn) SetWriteDeadline(t time.Time) error {
return tls.Conn.SetWriteDeadline(t)
}
func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
// TCP is a stream. Multiple TLS messages can arrive at the same time,
// a single message can also be segmented due to MTU of the IP layer.
// This function guareentees a single TLS message to be read and everything
// else is left in the buffer.
if len(buffer) < recordLayerLength {
return 0, io.ErrShortBuffer
}
_, err = io.ReadFull(tls.Conn, buffer[:recordLayerLength])
if err != nil {
return
}
dataLength := int(binary.BigEndian.Uint16(buffer[3:5]))
if dataLength > len(buffer) {
err = io.ErrShortBuffer
return
}
// we overwrite the record layer here
return io.ReadFull(tls.Conn, buffer[:dataLength])
}
func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in)
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
return 0, errors.New("message is too long")
}
writeBuf := tls.writeBufPool.Get().(*[]byte)
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
*writeBuf = append(*writeBuf, in...)
n, err = tls.Conn.Write(*writeBuf)
*writeBuf = (*writeBuf)[:3]
tls.writeBufPool.Put(writeBuf)
return n - recordLayerLength, err
}
func (tls *TLSConn) Close() error {
return tls.Conn.Close()
}

View File

@ -1,40 +0,0 @@
package common
import (
"net"
"testing"
)
func BenchmarkTLSConn_Write(b *testing.B) {
const bufSize = 16 * 1024
addrCh := make(chan string, 1)
go func() {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatal(err)
}
addrCh <- listener.Addr().String()
conn, err := listener.Accept()
if err != nil {
b.Fatal(err)
}
readBuf := make([]byte, bufSize*2)
for {
_, err = conn.Read(readBuf)
if err != nil {
return
}
}
}()
data := make([]byte, bufSize)
discardConn, _ := net.Dial("tcp", <-addrCh)
tlsConn := NewTLSConn(discardConn)
defer tlsConn.Close()
b.SetBytes(bufSize)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
tlsConn.Write(data)
}
})
}

View File

@ -1,24 +0,0 @@
package common
import (
"crypto/rand"
"io"
"time"
)
var RealWorldState = WorldState{
Rand: rand.Reader,
Now: time.Now,
}
type WorldState struct {
Rand io.Reader
Now func() time.Time
}
func WorldOfTime(t time.Time) WorldState {
return WorldState{
Rand: rand.Reader,
Now: func() time.Time { return t },
}
}

View File

@ -68,11 +68,13 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) {
return &pub, true return &pub, true
} }
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) { func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte {
var priv, pub *[32]byte var priv, pub, secret *[32]byte
priv = privKey.(*[32]byte) priv = privKey.(*[32]byte)
pub = pubKey.(*[32]byte) pub = pubKey.(*[32]byte)
secret = new([32]byte)
return curve25519.X25519(priv[:], pub[:]) curve25519.ScalarMult(secret, priv, pub)
return secret[:]
} }

View File

@ -32,7 +32,6 @@ import (
"bytes" "bytes"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"io"
"testing" "testing"
) )
@ -40,20 +39,6 @@ func TestCurve25519(t *testing.T) {
testECDH(t) testECDH(t)
} }
func TestErrors(t *testing.T) {
reader, writer := io.Pipe()
_ = writer.Close()
_, _, err := GenerateKey(reader)
if err == nil {
t.Error("GenerateKey should return error")
}
_, ok := Unmarshal([]byte{1})
if ok {
t.Error("Unmarshal should return false")
}
}
func BenchmarkCurve25519(b *testing.B) { func BenchmarkCurve25519(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
testECDH(b) testECDH(b)
@ -90,11 +75,11 @@ func testECDH(t testing.TB) {
t.Fatalf("Unmarshal does not work") t.Fatalf("Unmarshal does not work")
} }
secret1, err = GenerateSharedSecret(privKey1, pubKey2) secret1 = GenerateSharedSecret(privKey1, pubKey2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
secret2, err = GenerateSharedSecret(privKey2, pubKey1) secret2 = GenerateSharedSecret(privKey2, pubKey1)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }

View File

@ -6,30 +6,26 @@ import (
"bytes" "bytes"
"io" "io"
"sync" "sync"
"time"
) )
// The point of a streamBufferedPipe is that Read() will block until data is available const BUF_SIZE_LIMIT = 1 << 20 * 500
type streamBufferedPipe struct {
buf *bytes.Buffer
// The point of a bufferedPipe is that Read() will block until data is available
type bufferedPipe struct {
buf *bytes.Buffer
closed bool closed bool
rwCond *sync.Cond rwCond *sync.Cond
rDeadline time.Time
wtTimeout time.Duration
timeoutTimer *time.Timer
} }
func NewStreamBufferedPipe() *streamBufferedPipe { func NewBufferedPipe() *bufferedPipe {
p := &streamBufferedPipe{ p := &bufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer), buf: new(bytes.Buffer),
rwCond: sync.NewCond(&sync.Mutex{}),
} }
return p return p
} }
func (p *streamBufferedPipe) Read(target []byte) (int, error) { func (p *bufferedPipe) Read(target []byte) (int, error) {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
for { for {
@ -37,19 +33,9 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
} }
hasRDeadline := !p.rDeadline.IsZero()
if hasRDeadline {
if time.Until(p.rDeadline) <= 0 {
return 0, ErrTimeout
}
}
if p.buf.Len() > 0 { if p.buf.Len() > 0 {
break break
} }
if hasRDeadline {
p.broadcastAfter(time.Until(p.rDeadline))
}
p.rwCond.Wait() p.rwCond.Wait()
} }
n, err := p.buf.Read(target) n, err := p.buf.Read(target)
@ -58,14 +44,14 @@ func (p *streamBufferedPipe) Read(target []byte) (int, error) {
return n, err return n, err
} }
func (p *streamBufferedPipe) Write(input []byte) (int, error) { func (p *bufferedPipe) Write(input []byte) (int, error) {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
for { for {
if p.closed { if p.closed {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
if p.buf.Len() <= recvBufferSizeLimit { if p.buf.Len() <= BUF_SIZE_LIMIT {
// if p.buf gets too large, write() will panic. We don't want this to happen // if p.buf gets too large, write() will panic. We don't want this to happen
break break
} }
@ -77,7 +63,7 @@ func (p *streamBufferedPipe) Write(input []byte) (int, error) {
return n, err return n, err
} }
func (p *streamBufferedPipe) Close() error { func (p *bufferedPipe) Close() error {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
@ -86,17 +72,8 @@ func (p *streamBufferedPipe) Close() error {
return nil return nil
} }
func (p *streamBufferedPipe) SetReadDeadline(t time.Time) { func (p *bufferedPipe) Len() int {
p.rwCond.L.Lock() p.rwCond.L.Lock()
defer p.rwCond.L.Unlock() defer p.rwCond.L.Unlock()
return p.buf.Len()
p.rDeadline = t
p.rwCond.Broadcast()
}
func (p *streamBufferedPipe) broadcastAfter(d time.Duration) {
if p.timeoutTimer != nil {
p.timeoutTimer.Stop()
}
p.timeoutTimer = time.AfterFunc(d, p.rwCond.Broadcast)
} }

View File

@ -0,0 +1,207 @@
package multiplex
import (
"bytes"
"math/rand"
"testing"
"time"
)
func TestPipeRW(t *testing.T) {
pipe := NewBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b)
if n != len(b) {
t.Error(
"For", "number of bytes written",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple write",
"expecting", "nil error",
"got", err,
)
return
}
b2 := make([]byte, len(b))
n, err = pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
}
}
func TestReadBlock(t *testing.T) {
pipe := NewBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(10 * time.Millisecond)
pipe.Write(b)
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read after block",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
}
func TestPartialRead(t *testing.T) {
pipe := NewBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b1 := make([]byte, 1)
n, err := pipe.Read(b1)
if n != len(b1) {
t.Error(
"For", "number of bytes in partial read of 1",
"expecting", len(b1),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "partial read of 1",
"expecting", "nil error",
"got", err,
)
return
}
if b1[0] != b[0] {
t.Error(
"For", "partial read of 1",
"expecting", b[0],
"got", b1[0],
)
}
b2 := make([]byte, 2)
n, err = pipe.Read(b2)
if n != len(b2) {
t.Error(
"For", "number of bytes in partial read of 2",
"expecting", len(b2),
"got", n,
)
}
if err != nil {
t.Error(
"For", "partial read of 2",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b[1:], b2) {
t.Error(
"For", "partial read of 2",
"expecting", b[1:],
"got", b2,
)
return
}
}
func TestReadAfterClose(t *testing.T) {
pipe := NewBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
}
func BenchmarkBufferedPipe_RW(b *testing.B) {
const PAYLOAD_LEN = 1000
testData := make([]byte, PAYLOAD_LEN)
rand.Read(testData)
pipe := NewBufferedPipe()
smallBuf := make([]byte, PAYLOAD_LEN-10)
go func() {
for {
pipe.Read(smallBuf)
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := pipe.Write(testData)
if err != nil {
b.Error(
"For", "pipe write",
"got", err,
)
return
}
b.SetBytes(PAYLOAD_LEN)
}
}

View File

@ -0,0 +1,89 @@
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"errors"
"io"
"sync"
)
const DATAGRAM_NUMBER_LIMIT = 1024
// datagramBuffer is the same as bufferedPipe with the exception that it's message-oriented,
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
// it won't get chopped up into individual bytes
type datagramBuffer struct {
buf [][]byte
closed bool
rwCond *sync.Cond
}
func NewDatagramBuffer() *datagramBuffer {
d := &datagramBuffer{
buf: make([][]byte, 0),
rwCond: sync.NewCond(&sync.Mutex{}),
}
return d
}
func (d *datagramBuffer) Read(target []byte) (int, error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed && len(d.buf) == 0 {
return 0, io.EOF
}
if len(d.buf) > 0 {
break
}
d.rwCond.Wait()
}
data := d.buf[0]
if len(target) < len(data) {
return 0, errors.New("buffer is too small")
}
d.buf = d.buf[1:]
copy(target, data)
// err will always be nil because we have already verified that buf.Len() != 0
d.rwCond.Broadcast()
return len(data), nil
}
func (d *datagramBuffer) Write(f Frame) error {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed {
return io.ErrClosedPipe
}
if len(d.buf) <= DATAGRAM_NUMBER_LIMIT {
// if d.buf gets too large, write() will panic. We don't want this to happen
break
}
d.rwCond.Wait()
}
if f.Closing == 1 {
d.closed = true
d.rwCond.Broadcast()
return nil
}
data := make([]byte, len(f.Payload))
copy(data, f.Payload)
d.buf = append(d.buf, data)
// err will always be nil
d.rwCond.Broadcast()
return nil
}
func (d *datagramBuffer) Close() error {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.closed = true
d.rwCond.Broadcast()
return nil
}

View File

@ -0,0 +1,137 @@
package multiplex
import (
"bytes"
"testing"
"time"
)
func TestDatagramBuffer_RW(t *testing.T) {
b := []byte{0x01, 0x02, 0x03}
t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBuffer()
err := pipe.Write(Frame{Payload: b})
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
})
t.Run("simple read", func(t *testing.T) {
pipe := NewDatagramBuffer()
_ = pipe.Write(Frame{Payload: b})
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"expecting", b,
"got", b2,
)
}
if len(pipe.buf) != 0 {
t.Error("buf len is not 0 after finished reading")
return
}
})
t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBuffer()
err := pipe.Write(Frame{Closing: 1})
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !pipe.closed {
t.Error("expecting closed pipe, not closed")
}
})
}
func TestDatagramBuffer_BlockingRead(t *testing.T) {
pipe := NewDatagramBuffer()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(10 * time.Millisecond)
pipe.Write(Frame{Payload: b})
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read after block",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
}
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
pipe := NewDatagramBuffer()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(Frame{Payload: b})
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
}

View File

@ -1,119 +0,0 @@
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"bytes"
"io"
"sync"
"time"
)
// datagramBufferedPipe is the same as streamBufferedPipe with the exception that it's message-oriented,
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
// it won't get chopped up into individual bytes
type datagramBufferedPipe struct {
pLens []int
buf *bytes.Buffer
closed bool
rwCond *sync.Cond
wtTimeout time.Duration
rDeadline time.Time
timeoutTimer *time.Timer
}
func NewDatagramBufferedPipe() *datagramBufferedPipe {
d := &datagramBufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer),
}
return d
}
func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed && len(d.pLens) == 0 {
return 0, io.EOF
}
hasRDeadline := !d.rDeadline.IsZero()
if hasRDeadline {
if time.Until(d.rDeadline) <= 0 {
return 0, ErrTimeout
}
}
if len(d.pLens) > 0 {
break
}
if hasRDeadline {
d.broadcastAfter(time.Until(d.rDeadline))
}
d.rwCond.Wait()
}
dataLen := d.pLens[0]
if len(target) < dataLen {
return 0, io.ErrShortBuffer
}
d.pLens = d.pLens[1:]
d.buf.Read(target[:dataLen])
// err will always be nil because we have already verified that buf.Len() != 0
d.rwCond.Broadcast()
return dataLen, nil
}
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed {
return true, io.ErrClosedPipe
}
if d.buf.Len() <= recvBufferSizeLimit {
// if d.buf gets too large, write() will panic. We don't want this to happen
break
}
d.rwCond.Wait()
}
if f.Closing != closingNothing {
d.closed = true
d.rwCond.Broadcast()
return true, nil
}
dataLen := len(f.Payload)
d.pLens = append(d.pLens, dataLen)
d.buf.Write(f.Payload)
// err will always be nil
d.rwCond.Broadcast()
return false, nil
}
func (d *datagramBufferedPipe) Close() error {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.closed = true
d.rwCond.Broadcast()
return nil
}
func (d *datagramBufferedPipe) SetReadDeadline(t time.Time) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.rDeadline = t
d.rwCond.Broadcast()
}
func (d *datagramBufferedPipe) broadcastAfter(t time.Duration) {
if d.timeoutTimer != nil {
d.timeoutTimer.Stop()
}
d.timeoutTimer = time.AfterFunc(t, d.rwCond.Broadcast)
}

View File

@ -1,62 +0,0 @@
package multiplex
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDatagramBuffer_RW(t *testing.T) {
b := []byte{0x01, 0x02, 0x03}
t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
_, err := pipe.Write(&Frame{Payload: b})
assert.NoError(t, err)
})
t.Run("simple read", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
_, _ = pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n)
assert.Equal(t, b, b2)
assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
})
t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
assert.NoError(t, err)
assert.True(t, toBeClosed, "should be to be closed")
assert.True(t, pipe.closed, "pipe should be closed")
})
}
func TestDatagramBuffer_BlockingRead(t *testing.T) {
pipe := NewDatagramBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(readBlockTime)
pipe.Write(&Frame{Payload: b})
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
pipe := NewDatagramBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}

View File

@ -1,11 +1,5 @@
package multiplex package multiplex
const (
closingNothing = iota
closingStream
closingSession
)
type Frame struct { type Frame struct {
StreamID uint32 StreamID uint32
Seq uint64 Seq uint64

View File

@ -1,150 +0,0 @@
package multiplex
import (
"bytes"
"io"
"math/rand"
"net"
"sync"
"testing"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
func serveEcho(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
// TODO: pass the error back
return
}
go func(conn net.Conn) {
_, err := io.Copy(conn, conn)
if err != nil {
// TODO: pass the error back
return
}
}(conn)
}
}
type connPair struct {
clientConn net.Conn
serverConn net.Conn
}
func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
sessionKey := [32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
sessionId := 1
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
clientConfig := SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: false,
}
serverConfig := clientConfig
clientSession := MakeSession(uint32(sessionId), clientConfig)
serverSession := MakeSession(uint32(sessionId), serverConfig)
paris := make([]*connPair, numConn)
for i := 0; i < numConn; i++ {
c, s := connutil.AsyncPipe()
clientConn := common.NewTLSConn(c)
serverConn := common.NewTLSConn(s)
paris[i] = &connPair{
clientConn: clientConn,
serverConn: serverConn,
}
clientSession.AddConnection(clientConn)
serverSession.AddConnection(serverConn)
}
return clientSession, serverSession, paris
}
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
testData := make([]byte, msgLen)
rand.Read(testData)
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData)
if n != msgLen {
t.Errorf("written only %v, err %v", n, err)
return
}
recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Errorf("failed to read back: %v", err)
return
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("echoed data not correct")
return
}
}(conn)
}
wg.Wait()
}
func TestMultiplex(t *testing.T) {
const numStreams = 2000 // -race option limits the number of goroutines to 8192
const numConns = 4
const msgLen = 16384
clientSession, serverSession, _ := makeSessionPair(numConns)
go serveEcho(serverSession)
streams := make([]net.Conn, numStreams)
for i := 0; i < numStreams; i++ {
stream, err := clientSession.OpenStream()
assert.NoError(t, err)
streams[i] = stream
}
//test echo
runEchoTest(t, streams, msgLen)
assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
// close one stream
closing, streams := streams[0], streams[1:]
err := closing.Close()
assert.NoError(t, err, "couldn't close a stream")
_, err = closing.Write([]byte{0})
assert.Equal(t, ErrBrokenStream, err)
_, err = closing.Read(make([]byte, 1))
assert.Equal(t, ErrBrokenStream, err)
}
func TestMux_StreamClosing(t *testing.T) {
clientSession, serverSession, _ := makeSessionPair(1)
go serveEcho(serverSession)
// read after closing stream
testData := make([]byte, 128)
recvBuf := make([]byte, 128)
toBeClosed, _ := clientSession.OpenStream()
_, err := toBeClosed.Write(testData) // should be echoed back
assert.NoError(t, err, "couldn't write to a stream")
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
assert.NoError(t, err, "can't read anything before stream closed")
_ = toBeClosed.Close()
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
assert.NoError(t, err, "can't read residual data on stream")
assert.Equal(t, testData, recvBuf, "incorrect data read back")
}

View File

@ -7,195 +7,178 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/cbeuw/Cloak/internal/common"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/salsa20" "golang.org/x/crypto/salsa20"
) )
const frameHeaderLength = 14 type Obfser func(*Frame, []byte) (int, error)
const salsa20NonceSize = 8 type Deobfser func([]byte) (*Frame, error)
// maxExtraLen equals the max length of padding + AEAD tag. var u32 = binary.BigEndian.Uint32
// It is 255 bytes because the extra len field in frame header is only one byte. var u64 = binary.BigEndian.Uint64
const maxExtraLen = 1<<8 - 1 var putU32 = binary.BigEndian.PutUint32
var putU64 = binary.BigEndian.PutUint64
// padFirstNFrames specifies the number of initial frames to pad, const HEADER_LEN = 14
// to avoid TLS-in-TLS detection
const padFirstNFrames = 5
const ( const (
EncryptionMethodPlain = iota E_METHOD_PLAIN = iota
EncryptionMethodAES256GCM E_METHOD_AES_GCM
EncryptionMethodChaha20Poly1305 E_METHOD_CHACHA20_POLY1305
EncryptionMethodAES128GCM
) )
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames. func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Obfser {
type Obfuscator struct { var rlLen int
payloadCipher cipher.AEAD if hasRecordLayer {
rlLen = 5
sessionKey [32]byte }
} obfs := func(f *Frame, buf []byte) (int, error) {
// we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption
// obfuscate adds multiplexing headers, encrypt and add TLS header // this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible
func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) { // that the frame payload is smaller than 8 bytes, so we need to add on the difference
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header var extraLen uint8
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use if payloadCipher == nil {
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead()) if len(f.Payload) < 8 {
// as nonce for Salsa20 to encrypt the frame header. Both with sessionKey as keys. extraLen = uint8(8 - len(f.Payload))
//
// Several cryptographic guarantees we have made here: that payloadCipher, as an AEAD, is given a unique
// iv/nonce each time, relative to its key; that the frame header encryptor Salsa20 is given a unique
// nonce each time, relative to its key; and that the authenticity of frame header is checked.
//
// The payloadCipher is given a unique iv/nonce each time because it is derived from the frame header, which
// contains the monotonically increasing stream id (uint32) and frame sequence (uint64). There will be a nonce
// reuse after 2^64-1 frames sent (sent, not received because frames going different ways are sequenced
// independently) by a stream, or after 2^32-1 streams created in a single session. We consider these number
// to be large enough that they may never happen in reasonable time frames. Of course, different sessions
// will produce the same combination of stream id and frame sequence, but they will have different session keys.
//
//
// Because the frame header, before it being encrypted, is fed into the AEAD, it is also authenticated.
// (rfc5116 s.2.1 "The nonce is authenticated internally to the algorithm").
//
// In case the user chooses to not encrypt the frame payload, payloadCipher will be nil. In this scenario,
// we generate random bytes to be used as salsa20 nonce.
payloadLen := len(f.Payload)
if payloadLen == 0 {
return 0, errors.New("payload cannot be empty")
} }
tagLen := 0
if o.payloadCipher != nil {
tagLen = o.payloadCipher.Overhead()
} else { } else {
tagLen = salsa20NonceSize extraLen = uint8(payloadCipher.Overhead())
}
// Pad to avoid size side channel leak
padLen := 0
if f.Seq < padFirstNFrames {
padLen = common.RandInt(maxExtraLen - tagLen + 1)
} }
usefulLen := frameHeaderLength + payloadLen + padLen + tagLen // usefulLen is the amount of bytes that will be eventually sent off
usefulLen := rlLen + HEADER_LEN + len(f.Payload) + int(extraLen)
if len(buf) < usefulLen { if len(buf) < usefulLen {
return 0, errors.New("obfs buffer too small") return 0, errors.New("buffer is too small")
} }
// we do as much in-place as possible to save allocation // we do as much in-place as possible to save allocation
payload := buf[frameHeaderLength : frameHeaderLength+payloadLen+padLen] useful := buf[:usefulLen] // (tls header) + payload + potential overhead
if payloadOffsetInBuf != frameHeaderLength { header := useful[rlLen : rlLen+HEADER_LEN]
// if payload is not at the correct location in buffer encryptedPayloadWithExtra := useful[rlLen+HEADER_LEN:]
copy(payload, f.Payload)
}
header := buf[:frameHeaderLength] putU32(header[0:4], f.StreamID)
binary.BigEndian.PutUint32(header[0:4], f.StreamID) putU64(header[4:12], f.Seq)
binary.BigEndian.PutUint64(header[4:12], f.Seq)
header[12] = f.Closing header[12] = f.Closing
header[13] = byte(padLen + tagLen) header[13] = extraLen
// Random bytes for padding and nonce if payloadCipher == nil {
_, err := rand.Read(buf[frameHeaderLength+payloadLen : usefulLen]) copy(encryptedPayloadWithExtra, f.Payload)
if err != nil { if extraLen != 0 {
return 0, fmt.Errorf("failed to pad random: %w", err) rand.Read(encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-int(extraLen):])
}
} else {
ciphertext := payloadCipher.Seal(nil, header[:12], f.Payload, nil)
copy(encryptedPayloadWithExtra, ciphertext)
} }
if o.payloadCipher != nil { nonce := encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-8:]
o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil) salsa20.XORKeyStream(header, header, nonce, &salsaKey)
if hasRecordLayer {
recordLayer := useful[0:5]
// We don't use util.AddRecordLayer here to avoid unnecessary malloc
recordLayer[0] = 0x17
recordLayer[1] = 0x03
recordLayer[2] = 0x03
binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra)))
} }
// Composing final obfsed message
nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey)
return usefulLen, nil return usefulLen, nil
}
return obfs
} }
// deobfuscate removes TLS header, decrypt and unmarshall frames func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Deobfser {
func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error { var rlLen int
if len(in) < frameHeaderLength+salsa20NonceSize { if hasRecordLayer {
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize) rlLen = 5
}
deobfs := func(in []byte) (*Frame, error) {
if len(in) < rlLen+HEADER_LEN+8 {
return nil, fmt.Errorf("Input cannot be shorter than %v bytes", rlLen+HEADER_LEN+8)
} }
header := in[:frameHeaderLength] peeled := make([]byte, len(in)-rlLen)
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead copy(peeled, in[rlLen:])
nonce := in[len(in)-salsa20NonceSize:] header := peeled[:HEADER_LEN]
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey) pldWithOverHead := peeled[HEADER_LEN:] // payload + potential overhead
streamID := binary.BigEndian.Uint32(header[0:4]) nonce := peeled[len(peeled)-8:]
seq := binary.BigEndian.Uint64(header[4:12]) salsa20.XORKeyStream(header, header, nonce, &salsaKey)
streamID := u32(header[0:4])
seq := u64(header[4:12])
closing := header[12] closing := header[12]
extraLen := header[13] extraLen := header[13]
usefulPayloadLen := len(pldWithOverHead) - int(extraLen) usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { if usefulPayloadLen < 0 {
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") return nil, errors.New("extra length is greater than total pldWithOverHead length")
} }
var outputPayload []byte var outputPayload []byte
if o.payloadCipher == nil { if payloadCipher == nil {
if extraLen == 0 { if extraLen == 0 {
outputPayload = pldWithOverHead outputPayload = pldWithOverHead
} else { } else {
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
} else { } else {
_, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil) _, err := payloadCipher.Open(pldWithOverHead[:0], header[:12], pldWithOverHead, nil)
if err != nil { if err != nil {
return err return nil, err
} }
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
f.StreamID = streamID ret := &Frame{
f.Seq = seq StreamID: streamID,
f.Closing = closing Seq: seq,
f.Payload = outputPayload Closing: closing,
return nil Payload: outputPayload,
}
return ret, nil
}
return deobfs
} }
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) { func GenerateObfs(encryptionMethod byte, sessionKey []byte, hasRecordLayer bool) (obfuscator *Obfuscator, err error) {
o = Obfuscator{ if len(sessionKey) != 32 {
sessionKey: sessionKey, err = errors.New("sessionKey size must be 32 bytes")
} }
var salsaKey [32]byte
copy(salsaKey[:], sessionKey)
var payloadCipher cipher.AEAD
switch encryptionMethod { switch encryptionMethod {
case EncryptionMethodPlain: case E_METHOD_PLAIN:
o.payloadCipher = nil payloadCipher = nil
case EncryptionMethodAES256GCM: case E_METHOD_AES_GCM:
var c cipher.Block var c cipher.Block
c, err = aes.NewCipher(sessionKey[:]) c, err = aes.NewCipher(sessionKey)
if err != nil { if err != nil {
return return
} }
o.payloadCipher, err = cipher.NewGCM(c) payloadCipher, err = cipher.NewGCM(c)
if err != nil { if err != nil {
return return
} }
case EncryptionMethodAES128GCM: case E_METHOD_CHACHA20_POLY1305:
var c cipher.Block payloadCipher, err = chacha20poly1305.New(sessionKey)
c, err = aes.NewCipher(sessionKey[:16])
if err != nil {
return
}
o.payloadCipher, err = cipher.NewGCM(c)
if err != nil {
return
}
case EncryptionMethodChaha20Poly1305:
o.payloadCipher, err = chacha20poly1305.New(sessionKey[:])
if err != nil { if err != nil {
return return
} }
default: default:
return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod) return nil, errors.New("Unknown encryption method")
} }
if o.payloadCipher != nil { obfuscator = &Obfuscator{
if o.payloadCipher.NonceSize() > frameHeaderLength { MakeObfs(salsaKey, payloadCipher, hasRecordLayer),
return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header") MakeDeobfs(salsaKey, payloadCipher, hasRecordLayer),
sessionKey,
} }
}
return return
} }

View File

@ -1,128 +1,95 @@
package multiplex package multiplex
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"golang.org/x/crypto/chacha20poly1305"
"math/rand" "math/rand"
"reflect" "reflect"
"testing" "testing"
"testing/quick" "testing/quick"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
) )
func TestGenerateObfs(t *testing.T) { func TestGenerateObfs(t *testing.T) {
var sessionKey [32]byte sessionKey := make([]byte, 32)
rand.Read(sessionKey[:]) rand.Read(sessionKey)
run := func(o Obfuscator, t *testing.T) { run := func(obfuscator *Obfuscator, ct *testing.T) {
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
_testFrame, _ := quick.Value(reflect.TypeOf(Frame{}), rand.New(rand.NewSource(42))) f := &Frame{}
testFrame := _testFrame.Interface().(Frame) _testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42)))
i, err := o.obfuscate(&testFrame, obfsBuf, 0) testFrame := _testFrame.Interface().(*Frame)
assert.NoError(t, err) i, err := obfuscator.Obfs(testFrame, obfsBuf)
var resultFrame Frame if err != nil {
ct.Error("failed to obfs ", err)
return
}
err = o.deobfuscate(&resultFrame, obfsBuf[:i]) resultFrame, err := obfuscator.Deobfs(obfsBuf[:i])
assert.NoError(t, err) if err != nil {
assert.EqualValues(t, testFrame, resultFrame) ct.Error("failed to deobfs ", err)
return
}
if !bytes.Equal(testFrame.Payload, resultFrame.Payload) || testFrame.StreamID != resultFrame.StreamID {
ct.Error("expecting", testFrame,
"got", resultFrame)
return
}
} }
t.Run("plain", func(t *testing.T) { t.Run("plain", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
assert.NoError(t, err) if err != nil {
run(o, t) t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
}) })
t.Run("aes-256-gcm", func(t *testing.T) { t.Run("plain no record layer", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey) obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, false)
assert.NoError(t, err) if err != nil {
run(o, t) t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
}) })
t.Run("aes-128-gcm", func(t *testing.T) { t.Run("aes-gcm", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey) obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
assert.NoError(t, err) if err != nil {
run(o, t) t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
})
t.Run("aes-gcm no record layer", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, false)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
}) })
t.Run("chacha20-poly1305", func(t *testing.T) { t.Run("chacha20-poly1305", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) obfuscator, err := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
assert.NoError(t, err) if err != nil {
run(o, t) t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
}) })
t.Run("unknown encryption method", func(t *testing.T) { t.Run("unknown encryption method", func(t *testing.T) {
_, err := MakeObfuscator(0xff, sessionKey) _, err := GenerateObfs(0xff, sessionKey, true)
assert.Error(t, err) if err == nil {
t.Errorf("unknown encryption mehtod error expected")
}
}) })
} t.Run("bad key length", func(t *testing.T) {
_, err := GenerateObfs(0xff, sessionKey[:31], true)
func TestObfuscate(t *testing.T) { if err == nil {
var sessionKey [32]byte t.Errorf("bad key length error expected")
rand.Read(sessionKey[:])
const testPayloadLen = 1024
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
f := Frame{
StreamID: 0,
Seq: 0,
Closing: 0,
Payload: testPayload,
} }
runTest := func(t *testing.T, o Obfuscator) {
obfsBuf := make([]byte, testPayloadLen*2)
n, err := o.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
resultFrame := Frame{}
err = o.deobfuscate(&resultFrame, obfsBuf[:n])
assert.NoError(t, err)
assert.EqualValues(t, f, resultFrame)
}
t.Run("plain", func(t *testing.T) {
o := Obfuscator{
payloadCipher: nil,
sessionKey: sessionKey,
}
runTest(t, o)
}) })
t.Run("aes-128-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:16])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
t.Run("aes-256-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
t.Run("chacha20-poly1305", func(t *testing.T) {
payloadCipher, err := chacha20poly1305.New(sessionKey[:])
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
} }
func BenchmarkObfs(b *testing.B) { func BenchmarkObfs(b *testing.B) {
@ -135,7 +102,7 @@ func BenchmarkObfs(b *testing.B) {
testPayload, testPayload,
} }
obfsBuf := make([]byte, len(testPayload)*2) obfsBuf := make([]byte, 2048)
var key [32]byte var key [32]byte
rand.Read(key[:]) rand.Read(key[:])
@ -143,53 +110,56 @@ func BenchmarkObfs(b *testing.B) {
c, _ := aes.NewCipher(key[:]) c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{ obfs := MakeObfs(key, payloadCipher, true)
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0) n, err := obfs(testFrame, obfsBuf)
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16]) c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{ obfs := MakeObfs(key, payloadCipher, true)
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0) n, err := obfs(testFrame, obfsBuf)
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
obfuscator := Obfuscator{ obfs := MakeObfs(key, nil, true)
payloadCipher: nil,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0) n, err := obfs(testFrame, obfsBuf)
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:]) payloadCipher, _ := chacha20poly1305.New(key[:16])
obfuscator := Obfuscator{ obfs := MakeObfs(key, payloadCipher, true)
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0) n, err := obfs(testFrame, obfsBuf)
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
} }
@ -204,73 +174,76 @@ func BenchmarkDeobfs(b *testing.B) {
testPayload, testPayload,
} }
obfsBuf := make([]byte, len(testPayload)*2) obfsBuf := make([]byte, 2048)
var key [32]byte var key [32]byte
rand.Read(key[:]) rand.Read(key[:])
b.Run("AES256GCM", func(b *testing.B) { b.Run("AES256GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:]) c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0) obfs := MakeObfs(key, payloadCipher, true)
n, _ := obfs(testFrame, obfsBuf)
deobfs := MakeDeobfs(key, payloadCipher, true)
frame := new(Frame)
b.SetBytes(int64(n))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n]) _, err := deobfs(obfsBuf[:n])
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16]) c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c) payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{ obfs := MakeObfs(key, payloadCipher, true)
payloadCipher: payloadCipher, n, _ := obfs(testFrame, obfsBuf)
sessionKey: key, deobfs := MakeDeobfs(key, payloadCipher, true)
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n]) _, err := deobfs(obfsBuf[:n])
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
obfuscator := Obfuscator{ obfs := MakeObfs(key, nil, true)
payloadCipher: nil, n, _ := obfs(testFrame, obfsBuf)
sessionKey: key, deobfs := MakeDeobfs(key, nil, true)
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n]) _, err := deobfs(obfsBuf[:n])
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:]) payloadCipher, _ := chacha20poly1305.New(key[:16])
obfuscator := Obfuscator{ obfs := MakeObfs(key, payloadCipher, true)
payloadCipher: payloadCipher, n, _ := obfs(testFrame, obfsBuf)
sessionKey: key, deobfs := MakeDeobfs(key, payloadCipher, true)
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n]) _, err := deobfs(obfsBuf[:n])
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
}) })
} }

View File

@ -7,8 +7,9 @@ import (
) )
// Valve needs to be universal, across all sessions that belong to a user // Valve needs to be universal, across all sessions that belong to a user
// gabe please don't sue
type LimitedValve struct { type LimitedValve struct {
// traffic directions from the server's perspective are referred // traffic directions from the server's perspective are refered
// exclusively as rx and tx. // exclusively as rx and tx.
// rx is from client to server, tx is from server to client // rx is from client to server, tx is from server to client
// DO NOT use terms up or down as this is used in usermanager // DO NOT use terms up or down as this is used in usermanager

View File

@ -1,24 +1,8 @@
package multiplex package multiplex
import ( import "io"
"errors"
"io"
"time"
)
var ErrTimeout = errors.New("deadline exceeded")
type recvBuffer interface { type recvBuffer interface {
// Read calls' err must be nil | io.EOF | io.ErrShortBuffer
// Read should NOT return error on a closed streamBuffer with a non-empty buffer.
// Instead, it should behave as if it hasn't been closed. Closure is only relevant
// when the buffer is empty.
io.ReadCloser io.ReadCloser
Write(*Frame) (toBeClosed bool, err error) Write(Frame) error
SetReadDeadline(time time.Time)
} }
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
// a panic will happen.
const recvBufferSizeLimit = 1<<31 - 1

View File

@ -8,163 +8,116 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
acceptBacklog = 1024 acceptBacklog = 1024
defaultInactivityTimeout = 30 * time.Second
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
) )
var ErrBrokenSession = errors.New("broken session") var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session") var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type SessionConfig struct { // Obfuscator is responsible for the obfuscation and deobfuscation of frames
Obfuscator type Obfuscator struct {
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
// Valve is used to limit transmission rates, and record and limit usage Obfs Obfser
Valve // Remove TLS header, decrypt and unmarshall frames
Deobfs Deobfser
Unordered bool SessionKey []byte
}
// A Singleplexing session always has just one stream
Singleplex bool type switchboardStrategy int
// maximum size of an obfuscated frame, including headers and overhead type SessionConfig struct {
MsgOnWireSizeLimit int NoRecordLayer bool
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself *Obfuscator
InactivityTimeout time.Duration
Valve
// This is supposed to read one TLS message.
UnitRead func(net.Conn, []byte) (int, error)
Unordered bool
} }
// A Session represents a self-contained communication chain between local and remote. It manages its streams,
// controls serialisation and encryption of data sent and received using the supplied Obfuscator, and send and receive
// data through a manged connection pool filled with underlying connections added to it.
type Session struct { type Session struct {
id uint32 id uint32
SessionConfig *SessionConfig
// atomic // atomic
nextStreamID uint32 nextStreamID uint32
// atomic
activeStreamCount uint32
streamsM sync.Mutex streamsM sync.Mutex
streams map[uint32]*Stream streams map[uint32]*Stream
// For accepting new streams
acceptCh chan *Stream
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool
streamObfsBufPool sync.Pool
// Switchboard manages all connections to remote // Switchboard manages all connections to remote
sb *switchboard sb *switchboard
// Used for LocalAddr() and RemoteAddr() etc.
addrs atomic.Value addrs atomic.Value
// For accepting new streams
acceptCh chan *Stream
closed uint32 closed uint32
terminalMsgSetter sync.Once terminalMsg atomic.Value
terminalMsg string
// the max size passed to Write calls before it splits it into multiple frames
// i.e. the max size a piece of data can fit into a Frame.Payload
maxStreamUnitWrite int
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
streamSendBufferSize int
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
connReceiveBufferSize int
} }
func MakeSession(id uint32, config SessionConfig) *Session { func MakeSession(id uint32, config *SessionConfig) *Session {
sesh := &Session{ sesh := &Session{
id: id, id: id,
SessionConfig: config, SessionConfig: config,
nextStreamID: 1, nextStreamID: 1,
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
streams: map[uint32]*Stream{},
} }
sesh.addrs.Store([]net.Addr{nil, nil}) sesh.addrs.Store([]net.Addr{nil, nil})
if config.Valve == nil { if config.Valve == nil {
sesh.Valve = UNLIMITED_VALVE config.Valve = UNLIMITED_VALVE
}
if config.MsgOnWireSizeLimit <= 0 {
sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
}
if config.InactivityTimeout == 0 {
sesh.InactivityTimeout = defaultInactivityTimeout
} }
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - maxExtraLen sbConfig := &switchboardConfig{
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit Valve: config.Valve,
sesh.connReceiveBufferSize = 20480 // for backwards compatibility }
if config.Unordered {
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { log.Debug("Connection is unordered")
b := make([]byte, sesh.streamSendBufferSize) sbConfig.strategy = UNIFORM_SPREAD
return &b } else {
}} sbConfig.strategy = FIXED_CONN_MAPPING
}
sesh.sb = makeSwitchboard(sesh) sesh.sb = makeSwitchboard(sesh, sbConfig)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout) go sesh.timeoutAfter(30 * time.Second)
return sesh return sesh
} }
func (sesh *Session) GetSessionKey() [32]byte {
return sesh.sessionKey
}
func (sesh *Session) streamCountIncr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, 1)
}
func (sesh *Session) streamCountDecr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0))
}
func (sesh *Session) streamCount() uint32 {
return atomic.LoadUint32(&sesh.activeStreamCount)
}
// AddConnection is used to add an underlying connection to the connection pool
func (sesh *Session) AddConnection(conn net.Conn) { func (sesh *Session) AddConnection(conn net.Conn) {
sesh.sb.addConn(conn) sesh.sb.addConn(conn)
addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()} addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()}
sesh.addrs.Store(addrs) sesh.addrs.Store(addrs)
} }
// OpenStream is similar to net.Dial. It opens up a new stream
func (sesh *Session) OpenStream() (*Stream, error) { func (sesh *Session) OpenStream() (*Stream, error) {
if sesh.IsClosed() { if sesh.IsClosed() {
return nil, ErrBrokenSession return nil, ErrBrokenSession
} }
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1 id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
// Because atomic.AddUint32 returns the value after incrementation // Because atomic.AddUint32 returns the value after incrementation
if sesh.Singleplex && id > 1 { connId, err := sesh.sb.assignRandomConn()
// if there are more than one streams, which shouldn't happen if we are if err != nil {
// singleplexing return nil, err
return nil, errNoMultiplex
} }
stream := makeStream(sesh, id) stream := makeStream(sesh, id, connId)
sesh.streamsM.Lock() sesh.streamsM.Lock()
sesh.streams[id] = stream sesh.streams[id] = stream
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
sesh.streamCountIncr()
log.Tracef("stream %v of session %v opened", id, sesh.id) log.Tracef("stream %v of session %v opened", id, sesh.id)
return stream, nil return stream, nil
} }
// Accept is similar to net.Listener's Accept(). It blocks and returns an incoming stream
func (sesh *Session) Accept() (net.Conn, error) { func (sesh *Session) Accept() (net.Conn, error) {
if sesh.IsClosed() { if sesh.IsClosed() {
return nil, ErrBrokenSession return nil, ErrBrokenSession
@ -173,179 +126,103 @@ func (sesh *Session) Accept() (net.Conn, error) {
if stream == nil { if stream == nil {
return nil, ErrBrokenSession return nil, ErrBrokenSession
} }
sesh.streamsM.Lock()
sesh.streams[stream.id] = stream
sesh.streamsM.Unlock()
log.Tracef("stream %v of session %v accepted", stream.id, sesh.id) log.Tracef("stream %v of session %v accepted", stream.id, sesh.id)
return stream, nil return stream, nil
} }
func (sesh *Session) closeStream(s *Stream, active bool) error { func (sesh *Session) delStream(id uint32) {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
}
_ = s.recvBuf.Close() // recvBuf.Close should not return error
if active {
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
// Notify remote that this stream is closed
common.CryptoRandRead((*tmpBuf)[:1])
padLen := int((*tmpBuf)[0]) + 1
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
// must be holding s.wirtingM on entry
s.writingFrame.Closing = closingStream
s.writingFrame.Payload = payload
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
sesh.streamObfsBufPool.Put(tmpBuf)
if err != nil {
return err
}
log.Tracef("stream %v actively closed.", s.id)
} else {
log.Tracef("stream %v passively closed", s.id)
}
// We set it as nil to signify that the stream id had existed before.
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
// if the frame it received was from a new stream or a dying stream whose frame arrived late
sesh.streamsM.Lock() sesh.streamsM.Lock()
sesh.streams[s.id] = nil delete(sesh.streams, id)
if len(sesh.streams) == 0 {
log.Tracef("session %v has no active stream left", sesh.id)
go sesh.timeoutAfter(30 * time.Second)
}
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
if sesh.streamCountDecr() == 0 {
if sesh.Singleplex {
return sesh.Close()
} else {
log.Debugf("session %v has no active stream left", sesh.id)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
}
}
return nil
} }
// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
// stream and then writes to the stream buffer
func (sesh *Session) recvDataFromRemote(data []byte) error { func (sesh *Session) recvDataFromRemote(data []byte) error {
frame := sesh.recvFramePool.Get().(*Frame) frame, err := sesh.Deobfs(data)
defer sesh.recvFramePool.Put(frame)
err := sesh.deobfuscate(frame, data)
if err != nil { if err != nil {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
} }
if frame.Closing == closingSession { sesh.streamsM.Lock()
sesh.SetTerminalMsg("Received a closing notification frame") defer sesh.streamsM.Unlock()
return sesh.passiveClose() stream, existing := sesh.streams[frame.StreamID]
if existing {
return stream.writeFrame(*frame)
} else {
if frame.Closing == 1 {
// If the stream has been closed and the current frame is a closing frame, we do noop
return nil
} else {
// it may be tempting to use the connId from which the frame was received. However it doesn't make
// any difference because we only care to send the data from the same stream through the same
// TCP connection. The remote may use a different connection to send the same stream than the one the client
// use to send.
connId, _ := sesh.sb.assignRandomConn()
// we ignore the error here. If the switchboard is broken, it will be reflected upon stream.Write
stream = makeStream(sesh, frame.StreamID, connId)
sesh.acceptCh <- stream
return stream.writeFrame(*frame)
}
} }
sesh.streamsM.Lock()
if sesh.IsClosed() {
sesh.streamsM.Unlock()
return ErrBrokenSession
}
existingStream, existing := sesh.streams[frame.StreamID]
if existing {
sesh.streamsM.Unlock()
if existingStream == nil {
// this is when the stream existed before but has since been closed. We do nothing
return nil
}
return existingStream.recvFrame(frame)
} else {
newStream := makeStream(sesh, frame.StreamID)
sesh.streams[frame.StreamID] = newStream
sesh.acceptCh <- newStream
sesh.streamsM.Unlock()
// new stream
sesh.streamCountIncr()
return newStream.recvFrame(frame)
}
} }
func (sesh *Session) SetTerminalMsg(msg string) { func (sesh *Session) SetTerminalMsg(msg string) {
log.Debug("terminal message set to " + msg) sesh.terminalMsg.Store(msg)
sesh.terminalMsgSetter.Do(func() {
sesh.terminalMsg = msg
})
} }
func (sesh *Session) TerminalMsg() string { func (sesh *Session) TerminalMsg() string {
return sesh.terminalMsg msg := sesh.terminalMsg.Load()
} if msg != nil {
return msg.(string)
func (sesh *Session) closeSession() error { } else {
if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) { return ""
log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing
} }
sesh.streamsM.Lock()
close(sesh.acceptCh)
for id, stream := range sesh.streams {
if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
_ = stream.recvBuf.Close() // will not block
delete(sesh.streams, id)
sesh.streamCountDecr()
}
}
sesh.streamsM.Unlock()
return nil
}
func (sesh *Session) passiveClose() error {
log.Debugf("attempting to passively close session %v", sesh.id)
err := sesh.closeSession()
if err != nil {
return err
}
sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id)
return nil
} }
func (sesh *Session) Close() error { func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id) log.Debugf("attempting to close session %v", sesh.id)
err := sesh.closeSession() if atomic.SwapUint32(&sesh.closed, 1) == 1 {
if err != nil { log.Debugf("session %v has already been closed", sesh.id)
return err return errRepeatSessionClosing
} }
// we send a notice frame telling remote to close the session sesh.streamsM.Lock()
sesh.acceptCh <- nil
for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock
// because stream.Close calls sesh.delStream, which locks the mutex.
// so we need to implement a method of stream that closes the stream without calling
// sesh.delStream
go stream.closeNoDelMap()
delete(sesh.streams, id)
}
sesh.streamsM.Unlock()
buf := sesh.streamObfsBufPool.Get().(*[]byte)
common.CryptoRandRead((*buf)[:1])
padLen := int((*buf)[0]) + 1
payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
f := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: closingSession,
Payload: payload,
}
i, err := sesh.obfuscate(f, *buf, frameHeaderLength)
if err != nil {
return err
}
_, err = sesh.sb.send((*buf)[:i], new(net.Conn))
if err != nil {
return err
}
sesh.sb.closeAll() sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id) log.Debugf("session %v closed gracefully", sesh.id)
return nil return nil
} }
func (sesh *Session) IsClosed() bool { func (sesh *Session) IsClosed() bool {
return atomic.LoadUint32(&sesh.closed) == 1 return atomic.LoadUint32(&sesh.closed) == 1
} }
func (sesh *Session) checkTimeout() { func (sesh *Session) timeoutAfter(to time.Duration) {
if sesh.streamCount() == 0 && !sesh.IsClosed() { time.Sleep(to)
sesh.streamsM.Lock()
if len(sesh.streams) == 0 && !sesh.IsClosed() {
sesh.streamsM.Unlock()
sesh.SetTerminalMsg("timeout") sesh.SetTerminalMsg("timeout")
sesh.Close() sesh.Close()
} else {
sesh.streamsM.Unlock()
} }
} }

View File

@ -1,24 +0,0 @@
//go:build gofuzz
// +build gofuzz
package multiplex
func setupSesh_fuzz(unordered bool) *Session {
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, [32]byte{})
seshConfig := SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: unordered,
}
return MakeSession(0, seshConfig)
}
func Fuzz(data []byte) int {
sesh := setupSesh_fuzz(false)
err := sesh.recvDataFromRemote(data)
if err == nil {
return 1
}
return 0
}

View File

@ -2,639 +2,186 @@ package multiplex
import ( import (
"bytes" "bytes"
"io" "github.com/cbeuw/Cloak/internal/util"
"io/ioutil"
"math/rand" "math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"testing" "testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
) )
var seshConfigs = map[string]SessionConfig{ var seshConfigOrdered = &SessionConfig{
"ordered": {}, Obfuscator: nil,
"unordered": {Unordered: true}, Valve: nil,
} UnitRead: util.ReadTLS,
var encryptionMethods = map[string]byte{
"plain": EncryptionMethodPlain,
"aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
} }
const testPayloadLen = 1024 var seshConfigUnordered = &SessionConfig{
const obfsBufLen = testPayloadLen * 2 Obfuscator: nil,
Valve: nil,
UnitRead: util.ReadTLS,
Unordered: true,
}
func TestRecvDataFromRemote(t *testing.T) { func TestRecvDataFromRemote(t *testing.T) {
var sessionKey [32]byte testPayloadLen := 1024
rand.Read(sessionKey[:]) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
for seshType, seshConfig := range seshConfigs { f := &Frame{
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
var err error
seshConfig.Obfuscator, err = MakeObfuscator(EncryptionMethodPlain, sessionKey)
if err != nil {
t.Fatalf("failed to make obfuscator: %v", err)
}
t.Run("initial frame", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1, 1,
0, 0,
0, 0,
make([]byte, testPayloadLen), testPayload,
} }
rand.Read(f.Payload) obfsBuf := make([]byte, 17000)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err) sessionKey := make([]byte, 32)
err = sesh.recvDataFromRemote(obfsBuf[:n]) rand.Read(sessionKey)
assert.NoError(t, err) t.Run("plain ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
sesh.recvDataFromRemote(obfsBuf[:n])
stream, err := sesh.Accept() stream, err := sesh.Accept()
assert.NoError(t, err) if err != nil {
t.Error(err)
return
}
resultPayload := make([]byte, testPayloadLen) resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload) _, err = stream.Read(resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
t.Run("two frames in order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
make([]byte, testPayloadLen),
}
rand.Read(f.Payload)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
f.Seq += 1
rand.Read(f.Payload)
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
t.Run("two frames in order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
make([]byte, testPayloadLen),
}
rand.Read(f.Payload)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
f.Seq += 1
rand.Read(f.Payload)
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
if seshType == "ordered" {
t.Run("frames out of order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
nil,
}
// First frame
seq0 := make([]byte, testPayloadLen)
rand.Read(seq0)
f.Seq = 0
f.Payload = seq0
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Third frame
seq2 := make([]byte, testPayloadLen)
rand.Read(seq2)
f.Seq = 2
f.Payload = seq2
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Second frame
seq1 := make([]byte, testPayloadLen)
rand.Read(seq1)
f.Seq = 1
f.Payload = seq1
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Expect things to receive in order
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
// First
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq0, resultPayload)
// Second
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq1, resultPayload)
// Third
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq2, resultPayload)
})
}
})
}
}
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte
rand.Read(sessionKey[:])
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
sesh := MakeSession(0, seshConfig)
f1 := &Frame{
1,
0,
closingNothing,
testPayload,
}
// create stream 1
n, _ := sesh.obfuscate(f1, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
_, ok := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if !ok {
t.Fatal("failed to fetch stream 1 after receiving it")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1")
}
// create stream 2
f2 := &Frame{
2,
0,
closingNothing,
testPayload,
}
n, _ = sesh.obfuscate(f2, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err)
}
sesh.streamsM.Lock()
s2M, ok := sesh.streams[f2.StreamID]
sesh.streamsM.Unlock()
if s2M == nil || !ok {
t.Fatal("failed to fetch stream 2 after receiving it")
}
if sesh.streamCount() != 2 {
t.Error("stream count isn't 2")
}
// close stream 1
f1CloseStream := &Frame{
1,
1,
closingStream,
testPayload,
}
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
s1M, _ := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Fatal("stream 1 still exist after receiving stream close")
}
s1, _ := sesh.Accept()
if !s1.(*Stream).isClosed() {
t.Fatal("stream 1 not marked as closed")
}
payloadBuf := make([]byte, testPayloadLen)
_, err = s1.Read(payloadBuf)
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Fatalf("failed to read from stream 1 after closing: %v", err)
}
s2, _ := sesh.Accept()
if s2.(*Stream).isClosed() {
t.Fatal("stream 2 shouldn't be closed")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 closed")
}
// close stream 1 again
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
}
sesh.streamsM.Lock()
s1M, _ = sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Error("stream 1 exists after receiving stream close for the second time")
}
streamCount := sesh.streamCount()
if streamCount != 1 {
t.Errorf("stream count is %v after stream 1 closed twice, expected 1", streamCount)
}
// close session
fCloseSession := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: closingSession,
Payload: testPayload,
}
n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving session closing frame: %v", err)
}
if !sesh.IsClosed() {
t.Error("session not closed after receiving signal")
}
if !s2.(*Stream).isClosed() {
t.Error("stream 2 isn't closed after session closed")
}
if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from stream 2 after session closed")
}
if _, err := s2.Write(testPayload); err == nil {
t.Error("can still write to stream 2 after session closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after session closed")
}
}
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
// Tests for when the closing frame of a stream is received first before any data frame
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte
rand.Read(sessionKey[:])
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
sesh := MakeSession(0, seshConfig)
// receive stream 1 closing first
f1CloseStream := &Frame{
1,
1,
closingStream,
testPayload,
}
n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
_, ok := sesh.streams[f1CloseStream.StreamID]
sesh.streamsM.Unlock()
if !ok {
t.Fatal("stream 1 doesn't exist")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 received")
}
// receive data frame of stream 1 after receiving the closing frame
f1 := &Frame{
1,
0,
closingNothing,
testPayload,
}
n, _ = sesh.obfuscate(f1, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
s1, err := sesh.Accept()
if err != nil {
t.Fatal("failed to accept stream 1 after receiving it")
}
payloadBuf := make([]byte, testPayloadLen)
if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from steam 1")
}
if !s1.(*Stream).isClosed() {
t.Error("s1 isn't closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after stream 1 closed")
}
}
func TestParallelStreams(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
numStreams := acceptBacklog
seqs := make([]*uint64, numStreams)
for i := range seqs {
seqs[i] = new(uint64)
}
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
const numOfTests = 5000
tests := make([]struct {
name string
frame *Frame
}, numOfTests)
for i := range tests {
tests[i].name = strconv.Itoa(i)
tests[i].frame = randFrame()
}
var wg sync.WaitGroup
for _, tc := range tests {
wg.Add(1)
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
}
wg.Done()
}(tc.frame)
}
wg.Wait()
sc := int(sesh.streamCount())
var count int
sesh.streamsM.Lock()
for _, s := range sesh.streams {
if s != nil {
count++
}
}
sesh.streamsM.Unlock()
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
}
})
}
}
func TestStream_SetReadDeadline(t *testing.T) {
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
t.Run("read after deadline set", func(t *testing.T) {
stream, _ := sesh.OpenStream()
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
_, err := stream.Read(make([]byte, 1))
if err != ErrTimeout {
t.Errorf("expecting error %v, got %v", ErrTimeout, err)
}
})
t.Run("unblock when deadline passed", func(t *testing.T) {
stream, _ := sesh.OpenStream()
done := make(chan struct{})
go func() {
_, _ = stream.Read(make([]byte, 1))
done <- struct{}{}
}()
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
select {
case <-done:
return return
case <-time.After(500 * time.Millisecond): }
t.Error("Read did not unblock after deadline has passed") if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
} }
}) })
}) t.Run("aes-gcm ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
sesh.recvDataFromRemote(obfsBuf[:n])
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
} }
resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
}
})
t.Run("chacha20-poly1305 ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
sesh.recvDataFromRemote(obfsBuf[:n])
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
}
resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
}
})
t.Run("plain unordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigUnordered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
sesh.recvDataFromRemote(obfsBuf[:n])
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
}
resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
}
})
} }
func TestSession_timeoutAfter(t *testing.T) { func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
var sessionKey [32]byte testPayloadLen := 1024
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
seshConfig.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfig)
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
})
}
}
func BenchmarkRecvDataFromRemote(b *testing.B) {
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
f := Frame{ f := &Frame{
1, 1,
0, 0,
0, 0,
testPayload, testPayload,
} }
obfsBuf := make([]byte, 17000)
var sessionKey [32]byte sessionKey := make([]byte, 32)
rand.Read(sessionKey[:]) rand.Read(sessionKey)
const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic b.Run("plain", func(b *testing.B) {
for name, ep := range encryptionMethods { obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
ep := ep seshConfigOrdered.Obfuscator = obfuscator
b.Run(name, func(b *testing.B) { sesh := MakeSession(0, seshConfigOrdered)
for seshType, seshConfig := range seshConfigs { n, _ := sesh.Obfs(f, obfsBuf)
b.Run(seshType, func(b *testing.B) {
f := f
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig)
go func() {
stream, _ := sesh.Accept()
io.Copy(ioutil.Discard, stream)
}()
binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(&f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n]
f.Seq++
}
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(binaryFrames[i])
}
})
}
})
}
}
func BenchmarkMultiStreamWrite(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
testPayload := make([]byte, testPayloadLen)
for name, ep := range encryptionMethods {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
b.ResetTimer()
b.SetBytes(testPayloadLen)
b.RunParallel(func(pb *testing.PB) {
stream, _ := sesh.OpenStream()
for pb.Next() {
stream.Write(testPayload)
}
})
})
}
})
}
}
func BenchmarkLatency(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
for name, ep := range encryptionMethods {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
clientSesh := MakeSession(0, seshConfig)
serverSesh := MakeSession(0, seshConfig)
c, s := net.Pipe()
clientSesh.AddConnection(c)
serverSesh.AddConnection(s)
buf := make([]byte, 64)
clientStream, _ := clientSesh.OpenStream()
clientStream.Write(buf)
serverStream, _ := serverSesh.Accept()
io.ReadFull(serverStream, buf)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
clientStream.Write(buf) sesh.recvDataFromRemote(obfsBuf[:n])
io.ReadFull(serverStream, buf) b.SetBytes(int64(n))
} }
}) })
b.Run("aes-gcm", func(b *testing.B) {
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
b.SetBytes(int64(n))
} }
}) })
b.Run("chacha20-poly1305", func(b *testing.B) {
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
b.SetBytes(int64(n))
} }
})
} }

View File

@ -6,60 +6,53 @@ import (
"net" "net"
"time" "time"
log "github.com/sirupsen/logrus"
"math"
prand "math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
log "github.com/sirupsen/logrus"
) )
var ErrBrokenStream = errors.New("broken stream") var ErrBrokenStream = errors.New("broken stream")
// Stream implements net.Conn. It represents an optionally-ordered, full-duplex, self-contained connection.
// If the session it belongs to runs in ordered mode, it provides ordering guarantee regardless of the underlying
// connection used.
// If the underlying connections the session uses are reliable, Stream is reliable. If they are not, Stream does not
// guarantee reliability.
type Stream struct { type Stream struct {
id uint32 id uint32
session *Session session *Session
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
// been read by the consumer through Read or WriteTo.
recvBuf recvBuffer recvBuf recvBuffer
writingM sync.Mutex // atomic
writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom nextSendSeq uint64
writingM sync.RWMutex
// atomic // atomic
closed uint32 closed uint32
// When we want order guarantee (i.e. session.Unordered is false), obfsBuf []byte
// we assign each stream a fixed underlying connection.
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
// so it won't have to use its priority queue to sort it.
// This is not used in unordered connection mode
assignedConn net.Conn
readFromTimeout time.Duration // we assign each stream a fixed underlying TCP connection to utilise order guarantee provided by TCP itself
// so that frameSorter should have few to none ooo frames to deal with
// overall the streams in a session should be uniformly distributed across all connections
// This is not used in unordered connection mode
assignedConnId uint32
} }
func makeStream(sesh *Session, id uint32) *Stream { func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
var recvBuf recvBuffer
if sesh.Unordered {
recvBuf = NewDatagramBuffer()
} else {
recvBuf = NewStreamBuffer()
}
stream := &Stream{ stream := &Stream{
id: id, id: id,
session: sesh, session: sesh,
writingFrame: Frame{ recvBuf: recvBuf,
StreamID: id, obfsBuf: make([]byte, 17000), //TODO don't leave this hardcoded
Seq: 0, assignedConnId: assignedConnId,
Closing: closingNothing,
},
}
if sesh.Unordered {
stream.recvBuf = NewDatagramBufferedPipe()
} else {
stream.recvBuf = NewStreamBuffer()
} }
return stream return stream
@ -67,145 +60,129 @@ func makeStream(sesh *Session, id uint32) *Stream {
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
// receive a readily deobfuscated Frame so its payload can later be Read func (s *Stream) writeFrame(frame Frame) error {
func (s *Stream) recvFrame(frame *Frame) error { return s.recvBuf.Write(frame)
toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed {
err = s.passiveClose()
if errors.Is(err, errRepeatStreamClosing) {
log.Debug(err)
return nil
}
return err
}
return err
} }
// Read implements io.Read // Read implements io.Read
func (s *Stream) Read(buf []byte) (n int, err error) { func (s *Stream) Read(buf []byte) (n int, err error) {
//log.Tracef("attempting to read from stream %v", s.id) //log.Tracef("attempting to read from stream %v", s.id)
if len(buf) == 0 { if len(buf) == 0 {
if s.isClosed() {
return 0, ErrBrokenStream
} else {
return 0, nil return 0, nil
} }
}
n, err = s.recvBuf.Read(buf) n, err = s.recvBuf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
} }
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
return return
}
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf)
s.writingFrame.Seq++
if err != nil {
return err
}
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
}
return err
}
return nil
} }
// Write implements io.Write // Write implements io.Write
func (s *Stream) Write(in []byte) (n int, err error) { func (s *Stream) Write(in []byte) (n int, err error) {
s.writingM.Lock() // RWMutex used here isn't really for RW.
defer s.writingM.Unlock() // we use it to exploit the fact that RLock doesn't create contention.
// The use of RWMutex is so that the stream will not actively close
// in the middle of the execution of Write. This may cause the closing frame
// to be sent before the data frame and cause loss of packet.
//log.Tracef("attempting to write %v bytes to stream %v",len(in),s.id)
s.writingM.RLock()
defer s.writingM.RUnlock()
if s.isClosed() { if s.isClosed() {
return 0, ErrBrokenStream return 0, ErrBrokenStream
} }
for n < len(in) { f := &Frame{
var framePayload []byte StreamID: s.id,
if len(in)-n <= s.session.maxStreamUnitWrite { Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
// if we can fit remaining data of in into one frame Closing: 0,
framePayload = in[n:] Payload: in,
} else {
// if we have to split
if s.session.Unordered {
// but we are not allowed to
err = io.ErrShortBuffer
return
} }
framePayload = in[n : s.session.maxStreamUnitWrite+n]
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return i, err
} }
s.writingFrame.Payload = framePayload n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
buf := s.session.streamObfsBufPool.Get().(*[]byte) log.Tracef("%v sent to remote through stream %v with err %v", len(in), s.id, err)
err = s.obfuscateAndSend(*buf, 0)
s.session.streamObfsBufPool.Put(buf)
if err != nil { if err != nil {
return return
} }
n += len(framePayload) return len(in), nil
}
return
} }
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read // the necessary steps to mark the stream as closed and to release resources
// for readFromTimeout amount of time func (s *Stream) _close() {
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { atomic.StoreUint32(&s.closed, 1)
for { _ = s.recvBuf.Close() // both datagramBuffer and streamBuffer won't return err on Close()
if s.readFromTimeout != 0 {
if rder, ok := r.(net.Conn); !ok {
log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline")
} else {
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
}
}
buf := s.session.streamObfsBufPool.Get().(*[]byte)
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
if er != nil {
return n, er
}
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
// to check that here
if s.isClosed() {
return n, ErrBrokenStream
}
s.writingM.Lock()
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
err = s.obfuscateAndSend(*buf, frameHeaderLength)
s.writingM.Unlock()
s.session.streamObfsBufPool.Put(buf)
if err != nil {
return
}
n += int64(read)
}
} }
func (s *Stream) passiveClose() error { // only close locally. Used when the stream close is notified by the remote
return s.session.closeStream(s, false) func (s *Stream) passiveClose() {
s._close()
s.session.delStream(s.id)
log.Tracef("stream %v passively closed", s.id)
} }
// active close. Close locally and tell the remote that this stream is being closed // active close. Close locally and tell the remote that this stream is being closed
func (s *Stream) Close() error { func (s *Stream) Close() error {
s.writingM.Lock() s.writingM.Lock()
defer s.writingM.Unlock() defer s.writingM.Unlock()
if s.isClosed() {
return errors.New("Already Closed")
}
return s.session.closeStream(s, true) // Notify remote that this stream is closed
prand.Seed(int64(s.id))
padLen := int(math.Floor(prand.Float64()*200 + 300))
pad := make([]byte, padLen)
prand.Read(pad)
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: 1,
Payload: pad,
}
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return err
}
_, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
if err != nil {
return err
}
s._close()
s.session.delStream(s.id)
log.Tracef("stream %v actively closed", s.id)
return nil
} }
// Same as passiveClose() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
// We don't notify the remote because session.Close() is always
// called when the session is passively closed
func (s *Stream) closeNoDelMap() {
log.Tracef("stream %v closed by session", s.id)
s._close()
}
// the following functions are purely for implementing net.Conn interface.
// they are not used
var errNotImplemented = errors.New("Not implemented")
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
var errNotImplemented = errors.New("Not implemented")
// the following functions are purely for implementing net.Conn interface.
// they are not used
// TODO: implement the following // TODO: implement the following
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented } func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetReadDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented } func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }

View File

@ -7,14 +7,13 @@ package multiplex
// remote side before packet0. Cloak have to therefore sequence the packets so that they // remote side before packet0. Cloak have to therefore sequence the packets so that they
// arrive in order as they were sent by the proxy software // arrive in order as they were sent by the proxy software
// //
// Cloak packets will have a 64-bit sequence number on them, so we know in which order // Cloak packets will have a 32-bit sequence number on them, so we know in which order
// they should be sent to the proxy software. The code in this file provides buffering and sorting. // they should be sent to the proxy software. The code in this file provides buffering and sorting.
import ( import (
"container/heap" "container/heap"
"fmt" "fmt"
"sync" "sync"
"time"
) )
type sorterHeap []*Frame type sorterHeap []*Frame
@ -45,56 +44,56 @@ type streamBuffer struct {
recvM sync.Mutex recvM sync.Mutex
nextRecvSeq uint64 nextRecvSeq uint64
rev int
sh sorterHeap sh sorterHeap
buf *streamBufferedPipe buf *bufferedPipe
} }
// streamBuffer is a wrapper around streamBufferedPipe.
// Its main function is to sort frames in order, and wait for frames to arrive
// if they have arrived out-of-order. Then it writes the payload of frames into
// a streamBufferedPipe.
func NewStreamBuffer() *streamBuffer { func NewStreamBuffer() *streamBuffer {
sb := &streamBuffer{ sb := &streamBuffer{
sh: []*Frame{}, sh: []*Frame{},
buf: NewStreamBufferedPipe(), rev: 0,
buf: NewBufferedPipe(),
} }
return sb return sb
} }
func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) { // recvNewFrame is a forever running loop which receives frames unordered,
// cache and order them and send them into sortedBufCh
func (sb *streamBuffer) Write(f Frame) error {
sb.recvM.Lock() sb.recvM.Lock()
defer sb.recvM.Unlock() defer sb.recvM.Unlock()
// when there'fs no ooo packages in heap and we receive the next package in order // when there'fs no ooo packages in heap and we receive the next package in order
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq { if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
if f.Closing != closingNothing { if f.Closing == 1 {
return true, nil sb.buf.Close()
return nil
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
} }
return false, nil return nil
} }
if f.Seq < sb.nextRecvSeq { if f.Seq < sb.nextRecvSeq {
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) return fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
} }
saved := *f heap.Push(&sb.sh, &f)
saved.Payload = make([]byte, len(f.Payload))
copy(saved.Payload, f.Payload)
heap.Push(&sb.sh, &saved)
// Keep popping from the heap until empty or to the point that the wanted seq was not received // Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
f = heap.Pop(&sb.sh).(*Frame) f = *heap.Pop(&sb.sh).(*Frame)
if f.Closing != closingNothing { if f.Closing == 1 {
return true, nil // empty data indicates closing signal
sb.buf.Close()
return nil
} else { } else {
sb.buf.Write(f.Payload) sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1 sb.nextRecvSeq += 1
} }
} }
return false, nil return nil
} }
func (sb *streamBuffer) Read(buf []byte) (int, error) { func (sb *streamBuffer) Read(buf []byte) (int, error) {
@ -102,10 +101,5 @@ func (sb *streamBuffer) Read(buf []byte) (int, error) {
} }
func (sb *streamBuffer) Close() error { func (sb *streamBuffer) Close() error {
sb.recvM.Lock()
defer sb.recvM.Unlock()
return sb.buf.Close() return sb.buf.Close()
} }
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }

View File

@ -2,7 +2,7 @@ package multiplex
import ( import (
"encoding/binary" "encoding/binary"
"io" "time"
//"log" //"log"
"sort" "sort"
@ -21,11 +21,14 @@ func TestRecvNewFrame(t *testing.T) {
for _, n := range set { for _, n := range set {
bu64 := make([]byte, 8) bu64 := make([]byte, 8)
binary.BigEndian.PutUint64(bu64, n) binary.BigEndian.PutUint64(bu64, n)
sb.Write(&Frame{ frame := Frame{
Seq: n, Seq: n,
Payload: bu64, Payload: bu64,
})
} }
sb.Write(frame)
}
time.Sleep(100 * time.Millisecond)
var sortedResult []uint64 var sortedResult []uint64
for x := 0; x < len(set); x++ { for x := 0; x < len(set); x++ {
@ -42,7 +45,7 @@ func TestRecvNewFrame(t *testing.T) {
copy(targetSorted, set) copy(targetSorted, set)
sort.Slice(targetSorted, func(i, j int) bool { return targetSorted[i] < targetSorted[j] }) sort.Slice(targetSorted, func(i, j int) bool { return targetSorted[i] < targetSorted[j] })
for i := range targetSorted { for i, _ := range targetSorted {
if sortedResult[i] != targetSorted[i] { if sortedResult[i] != targetSorted[i] {
goto fail goto fail
} }
@ -69,23 +72,3 @@ func TestRecvNewFrame(t *testing.T) {
test(outOfOrder2, t) test(outOfOrder2, t)
}) })
} }
func TestStreamBuffer_RecvThenClose(t *testing.T) {
const testDataLen = 128
sb := NewStreamBuffer()
testData := make([]byte, testDataLen)
testFrame := Frame{
StreamID: 0,
Seq: 0,
Closing: 0,
Payload: testData,
}
sb.Write(&testFrame)
sb.Close()
readBuf := make([]byte, testDataLen)
_, err := io.ReadFull(sb, readBuf)
if err != nil {
t.Error(err)
}
}

View File

@ -1,93 +0,0 @@
package multiplex
import (
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const readBlockTime = 500 * time.Millisecond
func TestPipeRW(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b)
assert.NoError(t, err, "simple write")
assert.Equal(t, len(b), n, "number of bytes written")
b2 := make([]byte, len(b))
n, err = pipe.Read(b2)
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func TestReadBlock(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(readBlockTime)
pipe.Write(b)
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err, "blocked read")
assert.Equal(t, len(b), n, "number of bytes read after block")
assert.Equal(t, b, b2)
}
func TestPartialRead(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b1 := make([]byte, 1)
n, err := pipe.Read(b1)
assert.NoError(t, err, "partial read of 1")
assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
assert.Equal(t, b[0], b1[0])
b2 := make([]byte, 2)
n, err = pipe.Read(b2)
assert.NoError(t, err, "partial read of 2")
assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
assert.Equal(t, b[1:], b2)
}
func TestReadAfterClose(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func BenchmarkBufferedPipe_RW(b *testing.B) {
const PAYLOAD_LEN = 1000
testData := make([]byte, PAYLOAD_LEN)
rand.Read(testData)
pipe := NewStreamBufferedPipe()
smallBuf := make([]byte, PAYLOAD_LEN-10)
go func() {
for {
pipe.Read(smallBuf)
}
}()
b.SetBytes(int64(len(testData)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
pipe.Write(testData)
}
}

View File

@ -1,260 +1,80 @@
package multiplex package multiplex
import ( import (
"bufio"
"bytes" "bytes"
"io" "github.com/cbeuw/Cloak/internal/util"
"io/ioutil"
"math/rand" "math/rand"
"net"
"testing" "testing"
"time" "time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
"github.com/cbeuw/connutil"
) )
const payloadLen = 1000 func setupSesh(unordered bool) *Session {
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
obfuscator, _ := GenerateObfs(0x00, sessionKey, true)
var emptyKey [32]byte seshConfig := &SessionConfig{
func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session {
obfuscator, _ := MakeObfuscator(encryptionMethod, key)
seshConfig := SessionConfig{
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
UnitRead: util.ReadTLS,
Unordered: unordered, Unordered: unordered,
} }
return MakeSession(0, seshConfig) return MakeSession(0, seshConfig)
} }
func BenchmarkStream_Write_Ordered(b *testing.B) { type blackhole struct {
hole := connutil.Discard() hole *bufio.Writer
var sessionKey [32]byte
rand.Read(sessionKey[:])
const testDataLen = 65536
testData := make([]byte, testDataLen)
rand.Read(testData)
eMethods := map[string]byte{
"plain": EncryptionMethodPlain,
"chacha20-poly1305": EncryptionMethodChaha20Poly1305,
"aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
}
for name, method := range eMethods {
b.Run(name, func(b *testing.B) {
sesh := setupSesh(false, sessionKey, method)
sesh.AddConnection(hole)
stream, _ := sesh.OpenStream()
b.SetBytes(testDataLen)
b.ResetTimer()
for i := 0; i < b.N; i++ {
stream.Write(testData)
}
})
}
} }
func TestStream_Write(t *testing.T) { func newBlackHole() *blackhole { return &blackhole{hole: bufio.NewWriter(ioutil.Discard)} }
hole := connutil.Discard() func (b *blackhole) Read([]byte) (int, error) {
var sessionKey [32]byte time.Sleep(1 * time.Hour)
rand.Read(sessionKey[:]) return 0, nil
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain) }
func (b *blackhole) Write(in []byte) (int, error) { return b.hole.Write(in) }
func (b *blackhole) Close() error { return nil }
func (b *blackhole) LocalAddr() net.Addr {
ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1")
return ret
}
func (b *blackhole) RemoteAddr() net.Addr {
ret, _ := net.ResolveTCPAddr("tcp", "127.0.0.1")
return ret
}
func (b *blackhole) SetDeadline(t time.Time) error { return nil }
func (b *blackhole) SetReadDeadline(t time.Time) error { return nil }
func (b *blackhole) SetWriteDeadline(t time.Time) error { return nil }
func BenchmarkStream_Write_Ordered(b *testing.B) {
const PAYLOAD_LEN = 1000
hole := newBlackHole()
sesh := setupSesh(false)
sesh.AddConnection(hole) sesh.AddConnection(hole)
testData := make([]byte, payloadLen) testData := make([]byte, PAYLOAD_LEN)
rand.Read(testData) rand.Read(testData)
stream, _ := sesh.OpenStream() stream, _ := sesh.OpenStream()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := stream.Write(testData) _, err := stream.Write(testData)
if err != nil { if err != nil {
t.Error( b.Error(
"For", "stream write", "For", "stream write",
"got", err, "got", err,
) )
} }
b.SetBytes(PAYLOAD_LEN)
}
} }
func TestStream_WriteSync(t *testing.T) { func BenchmarkStream_Read_Ordered(b *testing.B) {
// Close calls made after write MUST have a higher seq sesh := setupSesh(false)
var sessionKey [32]byte const PAYLOAD_LEN = 1000
rand.Read(sessionKey[:]) testPayload := make([]byte, PAYLOAD_LEN)
clientSesh := setupSesh(false, sessionKey, EncryptionMethodPlain) rand.Read(testPayload)
serverSesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
w, r := connutil.AsyncPipe()
clientSesh.AddConnection(common.NewTLSConn(w))
serverSesh.AddConnection(common.NewTLSConn(r))
testData := make([]byte, payloadLen)
rand.Read(testData)
t.Run("test single", func(t *testing.T) {
go func() {
stream, _ := clientSesh.OpenStream()
stream.Write(testData)
stream.Close()
}()
recvBuf := make([]byte, payloadLen)
serverStream, _ := serverSesh.Accept()
_, err := io.ReadFull(serverStream, recvBuf)
if err != nil {
t.Error(err)
}
})
t.Run("test multiple", func(t *testing.T) {
const numStreams = 100
for i := 0; i < numStreams; i++ {
go func() {
stream, _ := clientSesh.OpenStream()
stream.Write(testData)
stream.Close()
}()
}
for i := 0; i < numStreams; i++ {
recvBuf := make([]byte, payloadLen)
serverStream, _ := serverSesh.Accept()
_, err := io.ReadFull(serverStream, recvBuf)
if err != nil {
t.Error(err)
}
}
})
}
func TestStream_Close(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
testPayload := []byte{42, 42, 42}
dataFrame := &Frame{
1,
0,
0,
testPayload,
}
t.Run("active closing", func(t *testing.T) {
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
rawConn, rawWritingEnd := connutil.AsyncPipe()
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512)
i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0)
_, err := writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Error("failed to write from remote end")
}
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
time.Sleep(500 * time.Millisecond)
err = stream.Close()
if err != nil {
t.Error("failed to actively close stream", err)
return
}
sesh.streamsM.Lock()
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
sesh.streamsM.Unlock()
t.Error("stream still exists")
return
}
sesh.streamsM.Unlock()
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf)
if err != nil {
t.Errorf("cannot read resiual data: %v", err)
}
if !bytes.Equal(readBuf, testPayload) {
t.Errorf("read wrong data")
}
})
t.Run("passive closing", func(t *testing.T) {
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
rawConn, rawWritingEnd := connutil.AsyncPipe()
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512)
i, err := sesh.obfuscate(dataFrame, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Error("failed to write from remote end")
}
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
closingFrame := &Frame{
1,
dataFrame.Seq + 1,
closingStream,
testPayload,
}
i, err = sesh.obfuscate(closingFrame, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Errorf("failed to write from remote end %v", err)
}
closingFrameDup := &Frame{
1,
dataFrame.Seq + 2,
closingStream,
testPayload,
}
i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Errorf("failed to write from remote end %v", err)
}
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf)
if err != nil {
t.Errorf("can't read residual data %v", err)
}
assert.Eventually(t, func() bool {
sesh.streamsM.Lock()
s, _ := sesh.streams[stream.(*Stream).id]
sesh.streamsM.Unlock()
return s == nil
}, time.Second, 10*time.Millisecond, "streams still exists")
})
}
func TestStream_Read(t *testing.T) {
seshes := map[string]bool{
"ordered": false,
"unordered": true,
}
testPayload := []byte{42, 42, 42}
const smallPayloadLen = 3
f := &Frame{ f := &Frame{
1, 1,
@ -263,21 +83,84 @@ func TestStream_Read(t *testing.T) {
testPayload, testPayload,
} }
var streamID uint32 obfsBuf := make([]byte, 17000)
for name, unordered := range seshes { l, _ := net.Listen("tcp", "127.0.0.1:0")
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain) go func() {
rawConn, rawWritingEnd := connutil.AsyncPipe() // potentially bottlenecked here rather than the actual stream read throughput
sesh.AddConnection(common.NewTLSConn(rawConn)) conn, _ := net.Dial("tcp", l.Addr().String())
writingEnd := common.NewTLSConn(rawWritingEnd) for {
t.Run(name, func(t *testing.T) { i, _ := sesh.Obfs(f, obfsBuf)
f.Seq += 1
_, err := conn.Write(obfsBuf[:i])
if err != nil {
b.Error("cannot write to connection", err)
}
}
}()
conn, _ := l.Accept()
sesh.AddConnection(conn)
stream, err := sesh.Accept()
if err != nil {
b.Error("failed to accept stream", err)
}
//time.Sleep(5*time.Second) // wait for buffer to fill up
readBuf := make([]byte, PAYLOAD_LEN)
b.ResetTimer()
for j := 0; j < b.N; j++ {
n, err := stream.Read(readBuf)
if !bytes.Equal(readBuf, testPayload) {
b.Error("paylod not equal")
}
b.SetBytes(int64(n))
if err != nil {
b.Error(err)
}
}
}
func TestStream_Read(t *testing.T) {
sesh := setupSesh(false)
testPayload := []byte{42, 42, 42}
const PAYLOAD_LEN = 3
f := &Frame{
1,
0,
0,
testPayload,
}
ch := make(chan []byte)
l, _ := net.Listen("tcp", "127.0.0.1:0")
go func() {
conn, _ := net.Dial("tcp", l.Addr().String())
for {
data := <-ch
_, err := conn.Write(data)
if err != nil {
t.Error("cannot write to connection", err)
return
}
}
}()
conn, _ := l.Accept()
sesh.AddConnection(conn)
var streamID uint32
buf := make([]byte, 10) buf := make([]byte, 10)
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) { t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0) i, _ := sesh.Obfs(f, obfsBuf)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, err := sesh.Accept() stream, err := sesh.Accept()
if err != nil { if err != nil {
t.Error("failed to accept stream", err) t.Error("failed to accept stream", err)
@ -288,8 +171,8 @@ func TestStream_Read(t *testing.T) {
t.Error("failed to read", err) t.Error("failed to read", err)
return return
} }
if i != smallPayloadLen { if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", smallPayloadLen, i) t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
return return
} }
if !bytes.Equal(buf[:i], testPayload) { if !bytes.Equal(buf[:i], testPayload) {
@ -300,34 +183,43 @@ func TestStream_Read(t *testing.T) {
}) })
t.Run("Nil buf", func(t *testing.T) { t.Run("Nil buf", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0) i, _ := sesh.Obfs(f, obfsBuf)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
i, err := stream.Read(nil) i, err := stream.Read(nil)
if i != 0 || err != nil { if i != 0 || err != nil {
t.Error("expecting", 0, nil, t.Error("expecting", 0, nil,
"got", i, err) "got", i, err)
} }
stream.Close()
i, err = stream.Read(nil)
if i != 0 || err != ErrBrokenStream {
t.Error("expecting", 0, ErrBrokenStream,
"got", i, err)
}
}) })
t.Run("Read after stream close", func(t *testing.T) { t.Run("Read after stream close", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0) i, _ := sesh.Obfs(f, obfsBuf)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
stream.Close() stream.Close()
i, err := stream.Read(buf)
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
if err != nil { if err != nil {
t.Errorf("cannot read residual data: %v", err) t.Error("failed to read", err)
} }
if !bytes.Equal(buf[:smallPayloadLen], testPayload) { if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload, t.Error("expected", testPayload,
"got", buf[:smallPayloadLen]) "got", buf[:i])
} }
_, err = stream.Read(buf) _, err = stream.Read(buf)
if err == nil { if err == nil {
@ -337,21 +229,22 @@ func TestStream_Read(t *testing.T) {
}) })
t.Run("Read after session close", func(t *testing.T) { t.Run("Read after session close", func(t *testing.T) {
f.StreamID = streamID f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0) i, _ := sesh.Obfs(f, obfsBuf)
streamID++ streamID++
writingEnd.Write(obfsBuf[:i]) ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept() stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
sesh.Close() sesh.Close()
_, err := io.ReadFull(stream, buf[:smallPayloadLen]) i, err := stream.Read(buf)
if err != nil { if err != nil {
t.Errorf("cannot read resiual data: %v", err) t.Error("failed to read", err)
} }
if !bytes.Equal(buf[:smallPayloadLen], testPayload) { if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload, t.Error("expected", testPayload,
"got", buf[:smallPayloadLen]) "got", buf[:i])
} }
_, err = stream.Read(buf) _, err = stream.Read(buf)
if err == nil { if err == nil {
@ -359,30 +252,132 @@ func TestStream_Read(t *testing.T) {
"got nil error") "got nil error")
} }
}) })
})
}
} }
func TestStream_SetReadFromTimeout(t *testing.T) { func TestStream_UnorderedRead(t *testing.T) {
seshes := map[string]*Session{ sesh := setupSesh(true)
"ordered": setupSesh(false, emptyKey, EncryptionMethodPlain), testPayload := []byte{42, 42, 42}
"unordered": setupSesh(true, emptyKey, EncryptionMethodPlain), const PAYLOAD_LEN = 3
f := &Frame{
1,
0,
0,
testPayload,
} }
for name, sesh := range seshes {
t.Run(name, func(t *testing.T) { ch := make(chan []byte)
stream, _ := sesh.OpenStream() l, _ := net.Listen("tcp", "127.0.0.1:0")
stream.SetReadFromTimeout(100 * time.Millisecond)
done := make(chan struct{})
go func() { go func() {
stream.ReadFrom(connutil.Discard()) conn, _ := net.Dial("tcp", l.Addr().String())
done <- struct{}{} for {
data := <-ch
_, err := conn.Write(data)
if err != nil {
t.Error("cannot write to connection", err)
}
}
}() }()
select { conn, _ := l.Accept()
case <-done: sesh.AddConnection(conn)
return
case <-time.After(500 * time.Millisecond): var streamID uint32
t.Error("didn't timeout") buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
}
i, err = stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
} }
}) })
t.Run("Nil buf", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
i, err := stream.Read(nil)
if i != 0 || err != nil {
t.Error("expecting", 0, nil,
"got", i, err)
} }
stream.Close()
i, err = stream.Read(nil)
if i != 0 || err != ErrBrokenStream {
t.Error("expecting", 0, ErrBrokenStream,
"got", i, err)
}
})
t.Run("Read after stream close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
stream.Close()
i, err := stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
t.Run("Read after session close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.Obfs(f, obfsBuf)
streamID++
ch <- obfsBuf[:i]
time.Sleep(100 * time.Microsecond)
stream, _ := sesh.Accept()
sesh.Close()
i, err := stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
} }

View File

@ -2,159 +2,165 @@ package multiplex
import ( import (
"errors" "errors"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math/rand/v2" "math/rand"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
) )
type switchboardStrategy int
const ( const (
fixedConnMapping switchboardStrategy = iota FIXED_CONN_MAPPING switchboardStrategy = iota
uniformSpread UNIFORM_SPREAD
) )
// switchboard represents the connection pool. It is responsible for managing type switchboardConfig struct {
// transport-layer connections between client and server. Valve
// It has several purposes: constantly receiving incoming data from all connections strategy switchboardStrategy
// and pass them to Session.recvDataFromRemote(); accepting data through }
// switchboard.send(), in which it selects a connection according to its
// switchboardStrategy and send the data off using that; and counting, as well as // switchboard is responsible for keeping the reference of TCP connections between client and server
// rate limiting, data received and sent through its Valve.
type switchboard struct { type switchboard struct {
session *Session session *Session
valve Valve *switchboardConfig
strategy switchboardStrategy
conns sync.Map connsM sync.RWMutex
connsCount uint32 conns map[uint32]net.Conn
randPool sync.Pool nextConnId uint32
broken uint32 broken uint32
} }
func makeSwitchboard(sesh *Session) *switchboard { func makeSwitchboard(sesh *Session, config *switchboardConfig) *switchboard {
// rates are uint64 because in the usermanager we want the bandwidth to be atomically
// operated (so that the bandwidth can change on the fly).
sb := &switchboard{ sb := &switchboard{
session: sesh, session: sesh,
strategy: uniformSpread, switchboardConfig: config,
valve: sesh.Valve, conns: make(map[uint32]net.Conn),
randPool: sync.Pool{New: func() interface{} {
var state [32]byte
common.CryptoRandRead(state[:])
return rand.New(rand.NewChaCha8(state))
}},
} }
return sb return sb
} }
var errNilOptimum = errors.New("The optimal connection is nil")
var errBrokenSwitchboard = errors.New("the switchboard is broken") var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) addConn(conn net.Conn) { func (sb *switchboard) addConn(conn net.Conn) {
connId := atomic.AddUint32(&sb.connsCount, 1) - 1 connId := atomic.AddUint32(&sb.nextConnId, 1) - 1
sb.conns.Store(connId, conn) sb.connsM.Lock()
go sb.deplex(conn) sb.conns[connId] = conn
sb.connsM.Unlock()
go sb.deplex(connId, conn)
} }
// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable func (sb *switchboard) removeConn(connId uint32) {
func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) { sb.connsM.Lock()
sb.valve.txWait(len(data)) delete(sb.conns, connId)
if atomic.LoadUint32(&sb.broken) == 1 { remaining := len(sb.conns)
return 0, errBrokenSwitchboard sb.connsM.Unlock()
if remaining == 0 {
atomic.StoreUint32(&sb.broken, 1)
sb.session.SetTerminalMsg("no underlying connection left")
sb.session.Close()
} }
var conn net.Conn
switch sb.strategy {
case uniformSpread:
conn, err = sb.pickRandConn()
if err != nil {
return 0, errBrokenSwitchboard
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
case fixedConnMapping:
// FIXME: this strategy has a tendency to cause a TLS conn socket buffer to fill up,
// which is a problem when multiple streams are mapped to the same conn, resulting
// in all such streams being blocked.
conn = *assignedConn
if conn == nil {
conn, err = sb.pickRandConn()
if err != nil {
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
sb.session.passiveClose()
return 0, err
}
*assignedConn = conn
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
default:
return 0, errors.New("unsupported traffic distribution strategy")
}
sb.valve.AddTx(int64(n))
return n, nil
} }
// returns a random conn. This function can be called concurrently. // a pointer to connId is passed here so that the switchboard can reassign it
func (sb *switchboard) pickRandConn() (net.Conn, error) { func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
if atomic.LoadUint32(&sb.broken) == 1 { sb.Valve.txWait(len(data))
return nil, errBrokenSwitchboard sb.connsM.RLock()
} defer sb.connsM.RUnlock()
if sb.strategy == UNIFORM_SPREAD {
connsCount := atomic.LoadUint32(&sb.connsCount) randConnId := rand.Intn(len(sb.conns))
if connsCount == 0 { conn, ok := sb.conns[uint32(randConnId)]
return nil, errBrokenSwitchboard
}
randReader := sb.randPool.Get().(*rand.Rand)
connId := randReader.Uint32N(connsCount)
sb.randPool.Put(randReader)
ret, ok := sb.conns.Load(connId)
if !ok { if !ok {
log.Errorf("failed to get conn %d", connId) return 0, errBrokenSwitchboard
return nil, errBrokenSwitchboard } else {
n, err = conn.Write(data)
sb.AddTx(int64(n))
return
} }
return ret.(net.Conn), nil } else {
var conn net.Conn
conn, ok := sb.conns[*connId]
if ok {
n, err = conn.Write(data)
sb.AddTx(int64(n))
return
} else {
// do not call assignRandomConn() here.
// we'll have to do connsM.RLock() after we get a new connId from assignRandomConn, in order to
// get the new conn through conns[newConnId]
// however between connsM.RUnlock() in assignRandomConn and our call to connsM.RLock(), things may happen.
// in particular if newConnId is removed between the RUnlock and RLock, conns[newConnId] will return
// a nil pointer. To prevent this we must get newConnId and the reference to conn itself in one single mutex
// protection
if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 {
return 0, errBrokenSwitchboard
}
r := rand.Intn(len(sb.conns))
var c int
for newConnId := range sb.conns {
if r == c {
connId = &newConnId
conn, _ = sb.conns[newConnId]
n, err = conn.Write(data)
sb.AddTx(int64(n))
return
}
c++
}
return 0, errBrokenSwitchboard
}
}
}
// returns a random connId
func (sb *switchboard) assignRandomConn() (uint32, error) {
sb.connsM.RLock()
defer sb.connsM.RUnlock()
if atomic.LoadUint32(&sb.broken) == 1 || len(sb.conns) == 0 {
return 0, errBrokenSwitchboard
}
r := rand.Intn(len(sb.conns))
var c int
for connId := range sb.conns {
if r == c {
return connId, nil
}
c++
}
return 0, errBrokenSwitchboard
} }
// actively triggered by session.Close() // actively triggered by session.Close()
func (sb *switchboard) closeAll() { func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { if atomic.SwapUint32(&sb.broken, 1) == 1 {
return return
} }
atomic.StoreUint32(&sb.connsCount, 0) sb.connsM.RLock()
sb.conns.Range(func(_, conn interface{}) bool { for key, conn := range sb.conns {
conn.(net.Conn).Close() conn.Close()
sb.conns.Delete(conn) delete(sb.conns, key)
return true }
}) sb.connsM.RUnlock()
} }
// deplex function costantly reads from a TCP connection // deplex function costantly reads from a TCP connection
func (sb *switchboard) deplex(conn net.Conn) { func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
defer conn.Close() buf := make([]byte, 20480)
buf := make([]byte, sb.session.connReceiveBufferSize)
for { for {
n, err := conn.Read(buf) n, err := sb.session.UnitRead(conn, buf)
sb.valve.rxWait(n) sb.rxWait(n)
sb.valve.AddRx(int64(n)) sb.Valve.AddRx(int64(n))
if err != nil { if err != nil {
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
sb.session.SetTerminalMsg("a connection has dropped unexpectedly") go conn.Close()
sb.session.passiveClose() sb.removeConn(connId)
return return
} }

View File

@ -1,100 +1,49 @@
package multiplex package multiplex
import ( import (
"github.com/cbeuw/Cloak/internal/util"
"math/rand" "math/rand"
"sync"
"sync/atomic"
"testing" "testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
) )
func TestSwitchboard_Send(t *testing.T) {
doTest := func(seshConfig SessionConfig) {
sesh := MakeSession(0, seshConfig)
hole0 := connutil.Discard()
sesh.sb.addConn(hole0)
conn, err := sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
data := make([]byte, 1000)
rand.Read(data)
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
hole1 := connutil.Discard()
sesh.sb.addConn(hole1)
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
}
t.Run("Ordered", func(t *testing.T) {
seshConfig := SessionConfig{
Unordered: false,
}
doTest(seshConfig)
})
t.Run("Unordered", func(t *testing.T) {
seshConfig := SessionConfig{
Unordered: true,
}
doTest(seshConfig)
})
}
func BenchmarkSwitchboard_Send(b *testing.B) { func BenchmarkSwitchboard_Send(b *testing.B) {
hole := connutil.Discard() seshConfig := &SessionConfig{
seshConfig := SessionConfig{} Obfuscator: nil,
Valve: nil,
UnitRead: util.ReadTLS,
}
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
hole := newBlackHole()
sesh.sb.addConn(hole) sesh.sb.addConn(hole)
conn, err := sesh.sb.pickRandConn() connId, err := sesh.sb.assignRandomConn()
if err != nil { if err != nil {
b.Error("failed to get a random conn", err) b.Error("failed to get a random conn", err)
return return
} }
data := make([]byte, 1000) data := make([]byte, 1000)
rand.Read(data) rand.Read(data)
b.SetBytes(int64(len(data)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
sesh.sb.send(data, &conn) n, err := sesh.sb.send(data, &connId)
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
} }
} }
func TestSwitchboard_TxCredit(t *testing.T) { func TestSwitchboard_TxCredit(t *testing.T) {
seshConfig := SessionConfig{ seshConfig := &SessionConfig{
Obfuscator: nil,
Valve: MakeValve(1<<20, 1<<20), Valve: MakeValve(1<<20, 1<<20),
UnitRead: util.ReadTLS,
} }
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
hole := connutil.Discard() hole := newBlackHole()
sesh.sb.addConn(hole) sesh.sb.addConn(hole)
conn, err := sesh.sb.pickRandConn() connId, err := sesh.sb.assignRandomConn()
if err != nil { if err != nil {
t.Error("failed to get a random conn", err) t.Error("failed to get a random conn", err)
return return
@ -102,10 +51,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
data := make([]byte, 1000) data := make([]byte, 1000)
rand.Read(data) rand.Read(data)
t.Run("fixed conn mapping", func(t *testing.T) { t.Run("FIXED CONN MAPPING", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0 *sesh.sb.Valve.(*LimitedValve).tx = 0
sesh.sb.strategy = fixedConnMapping sesh.sb.strategy = FIXED_CONN_MAPPING
n, err := sesh.sb.send(data[:10], &conn) n, err := sesh.sb.send(data[:10], &connId)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -114,14 +63,14 @@ func TestSwitchboard_TxCredit(t *testing.T) {
t.Errorf("wanted to send %v, got %v", 10, n) t.Errorf("wanted to send %v, got %v", 10, n)
return return
} }
if *sesh.sb.valve.(*LimitedValve).tx != 10 { if *sesh.sb.Valve.(*LimitedValve).tx != 10 {
t.Error("tx credit didn't increase by 10") t.Error("tx credit didn't increase by 10")
} }
}) })
t.Run("uniform spread", func(t *testing.T) { t.Run("UNIFORM", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0 *sesh.sb.Valve.(*LimitedValve).tx = 0
sesh.sb.strategy = uniformSpread sesh.sb.strategy = UNIFORM_SPREAD
n, err := sesh.sb.send(data[:10], &conn) n, err := sesh.sb.send(data[:10], &connId)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -130,58 +79,8 @@ func TestSwitchboard_TxCredit(t *testing.T) {
t.Errorf("wanted to send %v, got %v", 10, n) t.Errorf("wanted to send %v, got %v", 10, n)
return return
} }
if *sesh.sb.valve.(*LimitedValve).tx != 10 { if *sesh.sb.Valve.(*LimitedValve).tx != 10 {
t.Error("tx credit didn't increase by 10") t.Error("tx credit didn't increase by 10")
} }
}) })
} }
func TestSwitchboard_CloseOnOneDisconn(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
conn0client, conn0server := connutil.AsyncPipe()
sesh.AddConnection(conn0client)
conn1client, _ := connutil.AsyncPipe()
sesh.AddConnection(conn1client)
conn0server.Close()
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, time.Second, 10*time.Millisecond, "session not closed after one conn is disconnected")
if _, err := conn1client.Write([]byte{0x00}); err == nil {
t.Error("the other conn is still connected")
return
}
}
func TestSwitchboard_ConnsCount(t *testing.T) {
seshConfig := SessionConfig{
Valve: MakeValve(1<<20, 1<<20),
}
sesh := MakeSession(0, seshConfig)
var wg sync.WaitGroup
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
sesh.AddConnection(connutil.Discard())
wg.Done()
}()
}
wg.Wait()
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
t.Error("connsCount incorrect")
}
sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool {
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
}

View File

@ -1,101 +1,241 @@
package server package server
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io"
"net"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
"github.com/cbeuw/Cloak/internal/util"
log "github.com/sirupsen/logrus"
) )
const appDataMaxLength = 16401 // ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
type TLS struct{} var u16 = binary.BigEndian.Uint16
var u32 = binary.BigEndian.Uint32
var ErrBadClientHello = errors.New("non (or malformed) ClientHello") func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
func (TLS) String() string { return "TLS" } func parseKeyShare(input []byte) (ret []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("malformed key_share")
}
}()
totalLen := int(u16(input[0:2]))
// 2 bytes "client key share length"
pointer := 2
for pointer < totalLen {
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
// skip "key exchange length"
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
if length != 32 {
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
}
return input[pointer : pointer+length], nil
}
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
_ = input[pointer : pointer+length]
pointer += length
}
return nil, errors.New("x25519 does not exist")
}
func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { // addRecordLayer adds record layer to data
ch, err := parseClientHello(clientHello) func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
if err != nil { length := make([]byte, 2)
log.Debug(err) binary.BigEndian.PutUint16(length, uint16(len(input)))
err = ErrBadClientHello ret := make([]byte, 5+len(input))
return copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
// parseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func parseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
} }
fragments, err = TLS{}.unmarshalClientHello(ch, privateKey) peeled := make([]byte, len(data)-5)
if err != nil { copy(peeled, data[5:])
err = fmt.Errorf("failed to unmarshal ClientHello into authFragments: %v", err) pointer := 0
return // Handshake Type
handshakeType := peeled[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
pointer += 3
if length != len(peeled[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := peeled[pointer : pointer+2]
pointer += 2
// Random
random := peeled[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(peeled[pointer])
pointer += 1
sessionId := peeled[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(peeled[pointer])
pointer += 1
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
extensions, err := parseExtensions(peeled[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
} }
respond = TLS{}.makeResponder(ch.sessionId, fragments.sharedSecret)
return return
} }
func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Responder { func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) { nonce := make([]byte, 12)
// the cert length needs to be the same for all handshakes belonging to the same session rand.Read(nonce)
// we can use sessionKey as a seed here to ensure consistency
possibleCertLengths := []int{42, 27, 68, 59, 36, 44, 46}
cert := make([]byte, possibleCertLengths[common.RandInt(len(possibleCertLengths))])
common.RandRead(randSource, cert)
var nonce [12]byte encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
common.RandRead(randSource, nonce[:])
encryptedSessionKey, err := common.AESGCMEncrypt(nonce[:], sharedSecret[:], sessionKey[:])
if err != nil { if err != nil {
return return nil, err
} }
var encryptedSessionKeyArr [48]byte
copy(encryptedSessionKeyArr[:], encryptedSessionKey)
reply := composeReply(clientHelloSessionId, nonce, encryptedSessionKeyArr, cert) var serverHello [11][]byte
_, err = originalConn.Write(reply) serverHello[0] = []byte{0x02} // handshake type
if err != nil { serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
err = fmt.Errorf("failed to write TLS reply: %v", err) serverHello[2] = []byte{0x03, 0x03} // server version
originalConn.Close() serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
return serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = sessionId // session id
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
keyShare, _ := hex.DecodeString("00330024001d0020")
keyExchange := make([]byte, 32)
copy(keyExchange, encryptedKey[20:48])
rand.Read(keyExchange[28:32])
serverHello[9] = append(keyShare, keyExchange...)
serverHello[10], _ = hex.DecodeString("002b00020304")
var ret []byte
for _, s := range serverHello {
ret = append(ret, s...)
} }
preparedConn = common.NewTLSConn(originalConn) return ret, nil
return
}
return respond
} }
func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fragments authFragments, err error) { // composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
copy(fragments.randPubKey[:], ch.random) // together with their respective record layers into one byte slice.
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:]) func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
TLS12 := []byte{0x03, 0x03}
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
if err != nil {
return nil, err
}
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
cert := make([]byte, 68) // TODO: add some different lengths maybe?
rand.Read(cert)
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, encryptedCertBytes...)
return ret, nil
}
func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {
ephPub, ok := ecdh.Unmarshal(ch.random)
if !ok { if !ok {
err = ErrInvalidPubKey err = ErrInvalidPubKey
return return
} }
var sharedSecret []byte ai.nonce = ch.random[:12]
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
}
copy(fragments.sharedSecret[:], sharedSecret) ai.sharedSecret = ecdh.GenerateSharedSecret(staticPv, ephPub)
var keyShare []byte var keyShare []byte
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}]) keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
if err != nil { if err != nil {
return return
} }
ctxTag := append(ch.sessionId, keyShare...) ai.ciphertextWithTag = append(ch.sessionId, keyShare...)
if len(ctxTag) != 64 { if len(ai.ciphertextWithTag) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(ctxTag)) err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(ai.ciphertextWithTag))
return return
} }
copy(fragments.ciphertextWithTag[:], ctxTag)
return return
} }

View File

@ -1,202 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/common"
)
// ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
var u16 = binary.BigEndian.Uint16
var u32 = binary.BigEndian.Uint32
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
func parseKeyShare(input []byte) (ret []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("malformed key_share")
}
}()
totalLen := int(u16(input[0:2]))
// 2 bytes "client key share length"
pointer := 2
for pointer < totalLen {
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
// skip "key exchange length"
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
if length != 32 {
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
}
return input[pointer : pointer+length], nil
}
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
_ = input[pointer : pointer+length]
pointer += length
}
return nil, errors.New("x25519 does not exist")
}
// addRecordLayer adds record layer to data
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
// parseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func parseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
}
peeled := make([]byte, len(data)-5)
copy(peeled, data[5:])
pointer := 0
// Handshake Type
handshakeType := peeled[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
pointer += 3
if length != len(peeled[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := peeled[pointer : pointer+2]
pointer += 2
// Random
random := peeled[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(peeled[pointer])
pointer += 1
sessionId := peeled[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(peeled[pointer])
pointer += 1
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
extensions, err := parseExtensions(peeled[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
}
return
}
func composeServerHello(sessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte) []byte {
var serverHello [11][]byte
serverHello[0] = []byte{0x02} // handshake type
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 118
serverHello[2] = []byte{0x03, 0x03} // server version
serverHello[3] = append(nonce[0:12], encryptedSessionKeyWithTag[0:20]...) // random 32 bytes
serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = sessionId // session id
serverHello[6] = []byte{0x13, 0x02} // cipher suite TLS_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
keyShare := []byte{0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20}
keyExchange := make([]byte, 32)
copy(keyExchange, encryptedSessionKeyWithTag[20:48])
common.CryptoRandRead(keyExchange[28:32])
serverHello[9] = append(keyShare, keyExchange...)
serverHello[10] = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} // supported versions
var ret []byte
for _, s := range serverHello {
ret = append(ret, s...)
}
return ret
}
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
// together with their respective record layers into one byte slice.
func composeReply(clientHelloSessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte, cert []byte) []byte {
TLS12 := []byte{0x03, 0x03}
sh := composeServerHello(clientHelloSessionId, nonce, encryptedSessionKeyWithTag)
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, encryptedCertBytes...)
return ret
}

View File

@ -1,9 +1,8 @@
package server package server
import ( import (
"sync"
"github.com/cbeuw/Cloak/internal/server/usermanager" "github.com/cbeuw/Cloak/internal/server/usermanager"
"sync"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
) )
@ -40,7 +39,7 @@ func (u *ActiveUser) CloseSession(sessionID uint32, reason string) {
// GetSession returns the reference to an existing session, or if one such session doesn't exist, it queries // GetSession returns the reference to an existing session, or if one such session doesn't exist, it queries
// the UserManager for the authorisation for a new session. If a new session is allowed, it creates this new session // the UserManager for the authorisation for a new session. If a new session is allowed, it creates this new session
// and returns its reference // and returns its reference
func (u *ActiveUser) GetSession(sessionID uint32, config mux.SessionConfig) (sesh *mux.Session, existing bool, err error) { func (u *ActiveUser) GetSession(sessionID uint32, config *mux.SessionConfig) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock() u.sessionsM.Lock()
defer u.sessionsM.Unlock() defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil { if sesh = u.sessions[sessionID]; sesh != nil {

View File

@ -1,37 +1,17 @@
package server package server
import ( import (
"crypto/rand"
"encoding/base64" "encoding/base64"
"io/ioutil"
"os"
"testing"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server/usermanager" "github.com/cbeuw/Cloak/internal/server/usermanager"
"os"
"testing"
) )
func getSeshConfig(unordered bool) mux.SessionConfig {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := mux.MakeObfuscator(0x00, sessionKey)
seshConfig := mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: unordered,
}
return seshConfig
}
func TestActiveUser_Bypass(t *testing.T) { func TestActiveUser_Bypass(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME)
defer os.Remove(tmpDB.Name())
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState)
if err != nil { if err != nil {
t.Fatal("failed to make local manager", err) t.Error("failed to make local manager", err)
} }
panel := MakeUserPanel(manager) panel := MakeUserPanel(manager)
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==") UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
@ -39,85 +19,92 @@ func TestActiveUser_Bypass(t *testing.T) {
var sesh0 *mux.Session var sesh0 *mux.Session
var existing bool var existing bool
var sesh1 *mux.Session var sesh1 *mux.Session
t.Run("get first session", func(t *testing.T) {
// get first session sesh0, existing, err = user.GetSession(0, &mux.SessionConfig{})
sesh0, existing, err = user.GetSession(0, getSeshConfig(false))
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
if existing { if existing {
t.Fatal("get first session: first session returned as existing") t.Error("first session returned as existing")
} }
if sesh0 == nil { if sesh0 == nil {
t.Fatal("get first session: no session returned") t.Error("no session returned")
} }
})
// get first session again t.Run("get first session again", func(t *testing.T) {
seshx, existing, err := user.GetSession(0, mux.SessionConfig{}) seshx, existing, err := user.GetSession(0, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
if !existing { if !existing {
t.Fatal("get first session again: first session get again returned as not existing") t.Error("first session get again returned as not existing")
} }
if seshx == nil { if seshx == nil {
t.Fatal("get first session again: no session returned") t.Error("no session returned")
} }
if seshx != sesh0 { if seshx != sesh0 {
t.Fatal("returned a different instance") t.Error("returned a different instance")
} }
})
// get second session t.Run("get second session", func(t *testing.T) {
sesh1, existing, err = user.GetSession(1, getSeshConfig(false)) sesh1, existing, err = user.GetSession(1, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
if existing { if existing {
t.Fatal("get second session: second session returned as existing") t.Error("second session returned as existing")
} }
if sesh1 == nil { if sesh0 == nil {
t.Fatal("get second session: no session returned") t.Error("no session returned")
} }
})
t.Run("number of sessions", func(t *testing.T) {
if user.NumSession() != 2 { if user.NumSession() != 2 {
t.Fatal("number of session is not 2") t.Error("number of session is not 2")
} }
})
t.Run("delete a session", func(t *testing.T) {
user.CloseSession(0, "") user.CloseSession(0, "")
if user.NumSession() != 1 { if user.NumSession() != 1 {
t.Fatal("number of session is not 1 after deleting one") t.Error("number of session is not 1 after deleting one")
} }
if !sesh0.IsClosed() { if !sesh0.IsClosed() {
t.Fatal("session not closed after deletion") t.Error("session not closed after deletion")
} }
})
t.Run("close all sessions", func(t *testing.T) {
user.closeAllSessions("") user.closeAllSessions("")
if !sesh1.IsClosed() { if !sesh1.IsClosed() {
t.Fatal("session not closed after user termination") t.Error("session not closed after user termination")
} }
})
// get session again after termination t.Run("get session again after termination", func(t *testing.T) {
seshy, existing, err := user.GetSession(0, getSeshConfig(false)) seshx, existing, err := user.GetSession(0, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
if existing { if existing {
t.Fatal("get session again after termination: session returned as existing") t.Error("session returned as existing")
} }
if seshy == nil { if seshx == nil {
t.Fatal("get session again after termination: no session returned") t.Error("no session returned")
} }
if seshy == sesh0 || seshy == sesh1 { if seshx == sesh0 || seshx == sesh1 {
t.Fatal("get session after termination returned the same instance") t.Error("get session after termination returned the same instance")
} }
})
t.Run("delete last session", func(t *testing.T) {
user.CloseSession(0, "") user.CloseSession(0, "")
if panel.isActive(user.arrUID[:]) { if panel.isActive(user.arrUID[:]) {
t.Fatal("user still active after last session deleted") t.Error("user still active after last session deleted")
} }
})
err = manager.Close() err = manager.Close()
if err != nil { if err != nil {
t.Fatal("failed to close localmanager", err) t.Error("failed to close localmanager", err)
}
err = os.Remove(MOCK_DB_NAME)
if err != nil {
t.Error("failed to delete mockdb", err)
} }
} }

View File

@ -1,14 +1,18 @@
package server package server
import ( import (
"bufio"
"bytes" "bytes"
"crypto/rand"
"encoding/base64"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/cbeuw/Cloak/internal/util"
"net"
"net/http"
"time" "time"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -21,22 +25,25 @@ type ClientInfo struct {
Transport Transport Transport Transport
} }
type authFragments struct { type authenticationInfo struct {
sharedSecret [32]byte sharedSecret []byte
randPubKey [32]byte nonce []byte
ciphertextWithTag [64]byte ciphertextWithTag []byte
} }
const ( const (
UNORDERED_FLAG = 0x01 // 0000 0001 UNORDERED_FLAG = 0x01 // 0000 0001
) )
var ErrInvalidPubKey = errors.New("public key has invalid format")
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window") var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
var ErrUnreconisedProtocol = errors.New("unreconised protocol")
// decryptClientInfo checks if a the authFragments are valid. It doesn't check if the UID is authorised // touchStone checks if a the authenticationInfo are valid. It doesn't check if the UID is authorised
func decryptClientInfo(fragments authFragments, serverTime time.Time) (info ClientInfo, err error) { func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, err error) {
var plaintext []byte var plaintext []byte
plaintext, err = common.AESGCMDecrypt(fragments.randPubKey[0:12], fragments.sharedSecret[:], fragments.ciphertextWithTag[:]) plaintext, err = util.AESGCMDecrypt(ai.nonce, ai.sharedSecret, ai.ciphertextWithTag)
if err != nil { if err != nil {
return return
} }
@ -51,7 +58,8 @@ func decryptClientInfo(fragments authFragments, serverTime time.Time) (info Clie
timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37])) timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37]))
clientTime := time.Unix(timestamp, 0) clientTime := time.Unix(timestamp, 0)
if !(clientTime.After(serverTime.Add(-timestampTolerance)) && clientTime.Before(serverTime.Add(timestampTolerance))) { serverTime := now()
if !(clientTime.After(serverTime.Truncate(TIMESTAMP_TOLERANCE)) && clientTime.Before(serverTime.Add(TIMESTAMP_TOLERANCE))) {
err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp) err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp)
return return
} }
@ -59,31 +67,112 @@ func decryptClientInfo(fragments authFragments, serverTime time.Time) (info Clie
return return
} }
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
var ErrReplay = errors.New("duplicate random") var ErrReplay = errors.New("duplicate random")
var ErrBadProxyMethod = errors.New("invalid proxy method") var ErrBadProxyMethod = errors.New("invalid proxy method")
var ErrBadDecryption = errors.New("decryption/authentication failure")
// AuthFirstPacket checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client // PrepareConnection checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client
// if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user // if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with // is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
// the handshake // the handshake
func AuthFirstPacket(firstPacket []byte, transport Transport, sta *State) (info ClientInfo, finisher Responder, err error) { func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
fragments, finisher, err := transport.processFirstPacket(firstPacket, sta.StaticPv) var transport Transport
var ai authenticationInfo
switch firstPacket[0] {
case 0x47:
transport = WebSocket{}
var req *http.Request
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket)))
if err != nil { if err != nil {
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
return
}
var hiddenData []byte
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
ai, err = unmarshalHidden(hiddenData, sta.staticPv)
if err != nil {
err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err)
return return
} }
if sta.registerRandom(fragments.randPubKey) { finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
handler := newWsHandshakeHandler()
http.Serve(newWsAcceptor(conn, firstPacket), handler)
<-handler.finished
preparedConn = handler.conn
nonce := make([]byte, 12)
rand.Read(nonce)
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
if err != nil {
err = fmt.Errorf("failed to encrypt reply: %v", err)
return
}
reply := append(nonce, encryptedKey...)
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write reply: %v", err)
go preparedConn.Close()
return
}
return
}
case 0x16:
transport = TLS{}
var ch *ClientHello
ch, err = parseClientHello(firstPacket)
if err != nil {
log.Debug(err)
err = ErrBadClientHello
return
}
if sta.registerRandom(ch.random) {
err = ErrReplay err = ErrReplay
return return
} }
info, err = decryptClientInfo(fragments, sta.WorldState.Now().UTC()) ai, err = unmarshalClientHello(ch, sta.staticPv)
if err != nil {
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
return
}
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
preparedConn = conn
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
if err != nil {
err = fmt.Errorf("failed to compose TLS reply: %v", err)
return
}
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write TLS reply: %v", err)
go preparedConn.Close()
return
}
return
}
default:
err = ErrUnreconisedProtocol
return
}
info, err = touchStone(ai, sta.Now)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
err = fmt.Errorf("%w: %v", ErrBadDecryption, err) err = fmt.Errorf("transport %v in correct format but not Cloak: %v", transport, err)
return return
} }
info.Transport = transport info.Transport = transport
if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
err = ErrBadProxyMethod
return
}
return return
} }

View File

@ -3,15 +3,12 @@ package server
import ( import (
"crypto" "crypto"
"encoding/hex" "encoding/hex"
"fmt" "github.com/cbeuw/Cloak/internal/ecdh"
"testing" "testing"
"time" "time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
) )
func TestDecryptClientInfo(t *testing.T) { func TestTouchStone(t *testing.T) {
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547") pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
p, _ := ecdh.Unmarshal(pvBytes) p, _ := ecdh.Unmarshal(pvBytes)
staticPv := p.(crypto.PrivateKey) staticPv := p.(crypto.PrivateKey)
@ -19,14 +16,14 @@ func TestDecryptClientInfo(t *testing.T) {
t.Run("correct time", func(t *testing.T) { t.Run("correct time", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes) ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv) ai, err := unmarshalClientHello(ch, staticPv)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
} }
nineSixSix := time.Unix(1565998966, 0) nineSixSix := func() time.Time { return time.Unix(1565998966, 0) }
cinfo, err := decryptClientInfo(ai, nineSixSix) cinfo, err := touchStone(ai, nineSixSix)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
@ -35,40 +32,17 @@ func TestDecryptClientInfo(t *testing.T) {
t.Errorf("expecting session id 3710878841, got %v", cinfo.SessionId) t.Errorf("expecting session id 3710878841, got %v", cinfo.SessionId)
} }
}) })
t.Run("roughly correct time", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixP50 := time.Unix(1565998966, 0).Add(50)
_, err = decryptClientInfo(ai, nineSixSixP50)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixM50 := time.Unix(1565998966, 0).Add(-50)
_, err = decryptClientInfo(ai, nineSixSixM50)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
})
t.Run("over interval", func(t *testing.T) { t.Run("over interval", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes) ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv) ai, err := unmarshalClientHello(ch, staticPv)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
} }
nineSixSixOver := time.Unix(1565998966, 0).Add(timestampTolerance + 10) nineSixSixOver := func() time.Time { return time.Unix(1565998966, 0).Add(TIMESTAMP_TOLERANCE + 10) }
_, err = decryptClientInfo(ai, nineSixSixOver) _, err = touchStone(ai, nineSixSixOver)
if err == nil { if err == nil {
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err) t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
return return
@ -77,14 +51,14 @@ func TestDecryptClientInfo(t *testing.T) {
t.Run("under interval", func(t *testing.T) { t.Run("under interval", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes) ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv) ai, err := unmarshalClientHello(ch, staticPv)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
} }
nineSixSixUnder := time.Unix(1565998966, 0).Add(-(timestampTolerance + 10)) nineSixSixUnder := func() time.Time { return time.Unix(1565998966, 0).Add(TIMESTAMP_TOLERANCE - 10) }
_, err = decryptClientInfo(ai, nineSixSixUnder) _, err = touchStone(ai, nineSixSixUnder)
if err == nil { if err == nil {
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err) t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
return return
@ -93,14 +67,14 @@ func TestDecryptClientInfo(t *testing.T) {
t.Run("not cloak psk", func(t *testing.T) { t.Run("not cloak psk", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010246010002420303794ae79c6db7a31e67e2ce91b8afcb82995ae79ad1d0dc885f933e4193bf95cd208abd7a70f3b82cc31c02f1c2b94ba74d5222a66695a5cf92a366421d7f5eb9530022fafa130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001d75a5a00000000001e001c0000196c68332e676f6f676c6575736572636f6e74656e742e636f6d00170000ff01000100000a000a0008baba001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029baba000100001d002074bfe93336c364b43cf0879d997b2e11dc97068b86fc90174e0f2bcea1d4ed1c002d00020101002b000b0ababa0304030303020301001b00030200029a9a0001000029010500e000da00d1f6c0918f865390ae3ca33c77f61a1974cb4533456071b214ec018d17dc22845f2f72cf1dba48f9cdc0758803002dda9b964fad5522e82442af7cbbe242241e39233386f2383bce3ced8e16b1ae3f0ef52a706f58e1e6a1bca0cd3b3a2a4c4cb738770b01b56bf3e73c472bf4fb238cab510aa78f8427a3ca99f741aa433f548be460705f43a3abe878cec6ee3158c129406910b93e798e8a7aaffc2e7ff7b8fd872778d3687a0beaa1452fe7ec418070d537344b64d09f6edd053346ff9c9678eef6b8886882aba81d4be11d9df653de35659f93a22ac39399e3ba400021204e22b73261693967a9216fe4a3b004571c53f316309e76671a18d78931b5b072") chBytes, _ := hex.DecodeString("1603010246010002420303794ae79c6db7a31e67e2ce91b8afcb82995ae79ad1d0dc885f933e4193bf95cd208abd7a70f3b82cc31c02f1c2b94ba74d5222a66695a5cf92a366421d7f5eb9530022fafa130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001d75a5a00000000001e001c0000196c68332e676f6f676c6575736572636f6e74656e742e636f6d00170000ff01000100000a000a0008baba001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029baba000100001d002074bfe93336c364b43cf0879d997b2e11dc97068b86fc90174e0f2bcea1d4ed1c002d00020101002b000b0ababa0304030303020301001b00030200029a9a0001000029010500e000da00d1f6c0918f865390ae3ca33c77f61a1974cb4533456071b214ec018d17dc22845f2f72cf1dba48f9cdc0758803002dda9b964fad5522e82442af7cbbe242241e39233386f2383bce3ced8e16b1ae3f0ef52a706f58e1e6a1bca0cd3b3a2a4c4cb738770b01b56bf3e73c472bf4fb238cab510aa78f8427a3ca99f741aa433f548be460705f43a3abe878cec6ee3158c129406910b93e798e8a7aaffc2e7ff7b8fd872778d3687a0beaa1452fe7ec418070d537344b64d09f6edd053346ff9c9678eef6b8886882aba81d4be11d9df653de35659f93a22ac39399e3ba400021204e22b73261693967a9216fe4a3b004571c53f316309e76671a18d78931b5b072")
ch, _ := parseClientHello(chBytes) ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv) ai, err := unmarshalClientHello(ch, staticPv)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
} }
fiveOSix := time.Unix(1565999506, 0) fiveOSix := func() time.Time { return time.Unix(1565999506, 0) }
cinfo, err := decryptClientInfo(ai, fiveOSix) cinfo, err := touchStone(ai, fiveOSix)
if err == nil { if err == nil {
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo) t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
return return
@ -109,14 +83,14 @@ func TestDecryptClientInfo(t *testing.T) {
t.Run("not cloak no psk", func(t *testing.T) { t.Run("not cloak no psk", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303eae4c204a867390a758fcff3afa5803cac3e07011cf0c9f3befc1267445aabee20fc398df698113617f8161cbcb89534efa892088a6c5e49246534e05f790ea36f00220a0a130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001910a0a000000000014001200000f63646e2e62697a69626c652e636f6d00170000ff01000100000a000a0008caca001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029caca000100001d00204c8f1563fb70c261bc0c32c1b568b8d02fab25f4094711e7868b1712751dc754002d00020101002b000b0a2a2a0304030303020301001b00030200026a6a000100001500c9000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") chBytes, _ := hex.DecodeString("1603010200010001fc0303eae4c204a867390a758fcff3afa5803cac3e07011cf0c9f3befc1267445aabee20fc398df698113617f8161cbcb89534efa892088a6c5e49246534e05f790ea36f00220a0a130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001910a0a000000000014001200000f63646e2e62697a69626c652e636f6d00170000ff01000100000a000a0008caca001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029caca000100001d00204c8f1563fb70c261bc0c32c1b568b8d02fab25f4094711e7868b1712751dc754002d00020101002b000b0a2a2a0304030303020301001b00030200026a6a000100001500c9000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes) ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv) ai, err := unmarshalClientHello(ch, staticPv)
if err != nil { if err != nil {
t.Errorf("expecting no error, got %v", err) t.Errorf("expecting no error, got %v", err)
return return
} }
sixOneFive := time.Unix(1565999615, 0) sixOneFive := func() time.Time { return time.Unix(1565999615, 0) }
cinfo, err := decryptClientInfo(ai, sixOneFive) cinfo, err := touchStone(ai, sixOneFive)
if err == nil { if err == nil {
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo) t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
return return
@ -124,73 +98,3 @@ func TestDecryptClientInfo(t *testing.T) {
}) })
} }
func TestAuthFirstPacket(t *testing.T) {
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
p, _ := ecdh.Unmarshal(pvBytes)
getNewState := func() *State {
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1565998966, 0)))
sta.StaticPv = p.(crypto.PrivateKey)
sta.ProxyBook["shadowsocks"] = nil
return sta
}
t.Run("TLS correct", func(t *testing.T) {
sta := getNewState()
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
info, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
if err != nil {
t.Errorf("failed to get client info: %v", err)
return
}
if info.SessionId != 3710878841 {
t.Error("failed to get correct session id")
return
}
if info.Transport.(fmt.Stringer).String() != "TLS" {
t.Errorf("wrong transport: %v", info.Transport)
return
}
})
t.Run("TLS correct but replay", func(t *testing.T) {
sta := getNewState()
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
if err != nil {
t.Error("failed to prepare for the first time")
return
}
_, _, err = AuthFirstPacket(chBytes, TLS{}, sta)
if err != ErrReplay {
t.Errorf("failed to return ErrReplay, got %v instead", err)
return
}
})
t.Run("Websocket correct", func(t *testing.T) {
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1584358419, 0)))
sta.StaticPv = p.(crypto.PrivateKey)
sta.ProxyBook["shadowsocks"] = nil
req := `GET / HTTP/1.1
Host: d2jkinvisak5y9.cloudfront.net:443
User-Agent: Go-http-client/1.1
Connection: Upgrade
Hidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ
Sec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==
Sec-WebSocket-Version: 13
Upgrade: websocket
`
info, _, err := AuthFirstPacket([]byte(req), WebSocket{}, sta)
if err != nil {
t.Errorf("failed to get client info: %v", err)
return
}
if info.Transport.(fmt.Stringer).String() != "WebSocket" {
t.Errorf("wrong transport: %v", info.Transport)
return
}
})
}

View File

@ -1,311 +0,0 @@
package server
import (
"bytes"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/server/usermanager"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
var b64 = base64.StdEncoding.EncodeToString
const firstPacketSize = 3000
func Serve(l net.Listener, sta *State) {
waitDur := [10]time.Duration{
50 * time.Millisecond, 100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
3 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second, 30 * time.Second}
fails := 0
for {
conn, err := l.Accept()
if err != nil {
log.Errorf("%v, retrying", err)
time.Sleep(waitDur[fails])
if fails < 9 {
fails++
}
continue
}
fails = 0
go dispatchConnection(conn, sta)
}
}
func connReadLine(conn net.Conn, buf []byte) (int, error) {
i := 0
for ; i < len(buf); i++ {
_, err := io.ReadFull(conn, buf[i:i+1])
if err != nil {
return i, err
}
if buf[i] == '\n' {
return i + 1, nil
}
}
return i, io.ErrShortBuffer
}
var ErrUnrecognisedProtocol = errors.New("unrecognised protocol")
func readFirstPacket(conn net.Conn, buf []byte, timeout time.Duration) (int, Transport, bool, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
defer conn.SetReadDeadline(time.Time{})
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return 0, nil, false, err
}
// TODO: give the option to match the protocol with port
bufOffset := 1
var transport Transport
switch buf[0] {
case 0x16:
transport = TLS{}
recordLayerLength := 5
i, err := io.ReadFull(conn, buf[bufOffset:recordLayerLength])
bufOffset += i
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
dataLength := int(binary.BigEndian.Uint16(buf[3:5]))
if dataLength+recordLayerLength > len(buf) {
return bufOffset, transport, true, io.ErrShortBuffer
}
i, err = io.ReadFull(conn, buf[recordLayerLength:dataLength+recordLayerLength])
bufOffset += i
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
case 0x47:
transport = WebSocket{}
for {
i, err := connReadLine(conn, buf[bufOffset:])
line := buf[bufOffset : bufOffset+i]
bufOffset += i
if err != nil {
if err == io.ErrShortBuffer {
return bufOffset, transport, true, err
} else {
err = fmt.Errorf("error reading first packet: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
}
if bytes.Equal(line, []byte("\r\n")) {
break
}
}
default:
return bufOffset, transport, true, ErrUnrecognisedProtocol
}
return bufOffset, transport, true, nil
}
func dispatchConnection(conn net.Conn, sta *State) {
var err error
buf := make([]byte, firstPacketSize)
i, transport, redirOnErr, err := readFirstPacket(conn, buf, 15*time.Second)
data := buf[:i]
goWeb := func() {
redirPort := sta.RedirPort
if redirPort == "" {
_, redirPort, _ = net.SplitHostPort(conn.LocalAddr().String())
}
webConn, err := sta.RedirDialer.Dial("tcp", net.JoinHostPort(sta.RedirHost.String(), redirPort))
if err != nil {
log.Errorf("Making connection to redirection server: %v", err)
return
}
_, err = webConn.Write(data)
if err != nil {
log.Error("Failed to send first packet to redirection server", err)
return
}
go common.Copy(webConn, conn)
go common.Copy(conn, webConn)
}
if err != nil {
log.WithField("remoteAddr", conn.RemoteAddr()).
Warnf("error reading first packet: %v", err)
if redirOnErr {
goWeb()
} else {
conn.Close()
}
return
}
ci, finishHandshake, err := AuthFirstPacket(data, transport, sta)
if err != nil {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Warn(err)
goWeb()
return
}
var sessionKey [32]byte
common.RandRead(sta.WorldState.Rand, sessionKey[:])
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
if err != nil {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Error(err)
goWeb()
return
}
seshConfig := mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: ci.Unordered,
MsgOnWireSizeLimit: appDataMaxLength,
}
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if len(sta.AdminUID) != 0 && bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
sesh := mux.MakeSession(0, seshConfig)
preparedConn, err := finishHandshake(conn, sessionKey, sta.WorldState.Rand)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
sesh.AddConnection(preparedConn)
//TODO: Router could be nil in cnc mode
log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session")
err = http.Serve(sesh, usermanager.APIRouterOf(sta.Panel.Manager))
// http.Serve never returns with non-nil error
log.Error(err)
return
}
if _, ok := sta.ProxyBook[ci.ProxyMethod]; !ok {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Error(ErrBadProxyMethod)
goWeb()
return
}
var user *ActiveUser
if sta.IsBypass(ci.UID) {
user, err = sta.Panel.GetBypassUser(ci.UID)
} else {
user, err = sta.Panel.GetUser(ci.UID)
}
if err != nil {
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"remoteAddr": conn.RemoteAddr(),
"error": err,
}).Warn("+1 unauthorised UID")
goWeb()
return
}
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
if err != nil {
user.CloseSession(ci.SessionId, "")
log.Error(err)
return
}
preparedConn, err := finishHandshake(conn, sesh.GetSessionKey(), sta.WorldState.Rand)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
sesh.AddConnection(preparedConn)
if !existing {
// if the session was newly made, we serve connections from the session streams to the proxy server
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
}).Info("New session")
serveSession(sesh, ci, user, sta)
}
}
func serveSession(sesh *mux.Session, ci ClientInfo, user *ActiveUser, sta *State) error {
for {
newStream, err := sesh.Accept()
if err != nil {
if err == mux.ErrBrokenSession {
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
"reason": sesh.TerminalMsg(),
}).Info("Session closed")
user.CloseSession(ci.SessionId, "")
return nil
} else {
log.Errorf("unhandled error on session.Accept(): %v", err)
continue
}
}
proxyAddr := sta.ProxyBook[ci.ProxyMethod]
localConn, err := sta.ProxyDialer.Dial(proxyAddr.Network(), proxyAddr.String())
if err != nil {
log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
user.CloseSession(ci.SessionId, "Failed to connect to proxy server")
return err
}
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
go func() {
if _, err := common.Copy(localConn, newStream); err != nil {
log.Tracef("copying stream to proxy server: %v", err)
}
}()
go func() {
if _, err := common.Copy(newStream, localConn); err != nil {
log.Tracef("copying proxy server to stream: %v", err)
}
}()
}
}

View File

@ -1,211 +0,0 @@
package server
import (
"encoding/hex"
"io"
"net"
"testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
type rfpReturnValue struct {
n int
transport Transport
redirOnErr bool
err error
}
const timeout = 500 * time.Millisecond
func TestReadFirstPacket(t *testing.T) {
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue) {
ret := rfpReturnValue{}
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, timeout)
retChan <- ret
}
t.Run("Good TLS", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first)
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("Good TLS but buf too small", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 10)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first)
ret := <-retChan
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
assert.Equal(t, first[:ret.n], buf[:ret.n])
})
t.Run("Incomplete timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("160301")
local.Write(first)
select {
case ret := <-retChan:
assert.Equal(t, len(first), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Incomplete payload timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("16030101010000")
local.Write(first)
select {
case ret := <-retChan:
assert.Equal(t, len(first), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Good TLS staggered", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first[:100])
time.Sleep(timeout / 2)
local.Write(first[100:])
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("TLS bad recordlayer length", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("160301ffff")
local.Write(first)
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
})
t.Run("Good WebSocket", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req)
ret := <-retChan
assert.Equal(t, len(req), ret.n)
assert.Equal(t, req, buf[:ret.n])
assert.IsType(t, WebSocket{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("Good WebSocket but buf too small", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 10)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req)
ret := <-retChan
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
assert.Equal(t, req[:ret.n], buf[:ret.n])
})
t.Run("Incomplete WebSocket timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET /"
req := []byte(reqStr)
local.Write(req)
select {
case ret := <-retChan:
assert.Equal(t, len(req), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Staggered WebSocket", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req[:100])
time.Sleep(timeout / 2)
local.Write(req[100:])
ret := <-retChan
assert.Equal(t, len(req), ret.n)
assert.Equal(t, req, buf[:ret.n])
assert.IsType(t, WebSocket{}, ret.transport)
assert.NoError(t, ret.err)
})
}

View File

@ -1,64 +0,0 @@
//go:build gofuzz
// +build gofuzz
package server
import (
"errors"
"net"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/connutil"
)
type rfpReturnValue_fuzz struct {
n int
transport Transport
redirOnErr bool
err error
}
func Fuzz(data []byte) int {
var bypassUID [16]byte
var pv [32]byte
sta := &State{
BypassUID: map[[16]byte]struct{}{
bypassUID: {},
},
ProxyBook: map[string]net.Addr{
"shadowsocks": nil,
},
UsedRandom: map[[32]byte]int64{},
StaticPv: &pv,
WorldState: common.RealWorldState,
}
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue_fuzz) {
ret := rfpReturnValue_fuzz{}
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, 500*time.Millisecond)
retChan <- ret
}
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue_fuzz)
go rfp(remote, buf, retChan)
local.Write(data)
ret := <-retChan
if ret.err != nil {
return 1
}
_, _, err := AuthFirstPacket(buf[:ret.n], ret.transport, sta)
if !errors.Is(err, ErrReplay) && !errors.Is(err, ErrBadDecryption) {
return 1
}
return 0
}

View File

@ -2,200 +2,174 @@ package server
import ( import (
"crypto" "crypto"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/cbeuw/Cloak/internal/server/usermanager"
"github.com/sirupsen/logrus"
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/cbeuw/Cloak/internal/common" gmux "github.com/gorilla/mux"
"github.com/cbeuw/Cloak/internal/server/usermanager"
) )
type RawConfig struct { type rawConfig struct {
ProxyBook map[string][]string ProxyBook map[string][]string
BindAddr []string BindAddr []string
BypassUID [][]byte BypassUID [][]byte
RedirAddr string RedirAddr string
PrivateKey []byte PrivateKey string
AdminUID []byte AdminUID string
DatabasePath string DatabasePath string
KeepAlive int StreamTimeout int
CncMode bool CncMode bool
} }
// State type stores the global state of the program // State type stores the global state of the program
type State struct { type State struct {
BindAddr []net.Addr
ProxyBook map[string]net.Addr ProxyBook map[string]net.Addr
ProxyDialer common.Dialer
WorldState common.WorldState Now func() time.Time
AdminUID []byte AdminUID []byte
Timeout time.Duration
BypassUID map[[16]byte]struct{} BypassUID map[[16]byte]struct{}
StaticPv crypto.PrivateKey staticPv crypto.PrivateKey
// TODO: this doesn't have to be a net.Addr; resolution is done in Dial automatically RedirAddr net.Addr
RedirHost net.Addr
RedirPort string
RedirDialer common.Dialer
usedRandomM sync.RWMutex usedRandomM sync.RWMutex
UsedRandom map[[32]byte]int64 usedRandom map[[32]byte]int64
Panel *userPanel Panel *userPanel
LocalAPIRouter *gmux.Router
} }
func parseRedirAddr(redirAddr string) (net.Addr, string, error) { func InitState(nowFunc func() time.Time) (*State, error) {
var host string ret := &State{
var port string Now: nowFunc,
BypassUID: make(map[[16]byte]struct{}),
ProxyBook: map[string]net.Addr{},
usedRandom: map[[32]byte]int64{},
}
go ret.UsedRandomCleaner()
return ret, nil
}
// ParseConfig parses the config (either a path to json or the json itself as argument) into a State variable
func (sta *State) ParseConfig(conf string) (err error) {
var content []byte
var preParse rawConfig
content, errPath := ioutil.ReadFile(conf)
if errPath != nil {
errJson := json.Unmarshal(content, &preParse)
if errJson != nil {
return errors.New("Failed to read/unmarshal configuration, path is invalid or " + errJson.Error())
}
} else {
errJson := json.Unmarshal(content, &preParse)
if errJson != nil {
return errors.New("Failed to read configuration file: " + errJson.Error())
}
}
if preParse.CncMode {
//TODO: implement command & control mode
} else {
manager, err := usermanager.MakeLocalManager(preParse.DatabasePath)
if err != nil {
return err
}
sta.Panel = MakeUserPanel(manager)
sta.LocalAPIRouter = manager.Router
}
if preParse.StreamTimeout == 0 {
sta.Timeout = time.Duration(300) * time.Second
} else {
sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
}
redirAddr := preParse.RedirAddr
colonSep := strings.Split(redirAddr, ":") colonSep := strings.Split(redirAddr, ":")
if len(colonSep) > 1 { if len(colonSep) != 0 {
if len(colonSep) == 2 { if len(colonSep) == 2 {
// domain or ipv4 with port logrus.Error("If RedirAddr contains a port number, please remove it.")
host = colonSep[0] redirAddr = colonSep[0]
port = colonSep[1]
} else { } else {
if strings.Contains(redirAddr, "[") { if strings.Contains(redirAddr, "[") {
// ipv6 with port logrus.Error("If RedirAddr contains a port number, please remove it.")
port = colonSep[len(colonSep)-1] redirAddr = strings.TrimRight(redirAddr, "]:"+colonSep[len(colonSep)-1])
host = strings.TrimSuffix(redirAddr, "]:"+port) redirAddr = strings.TrimPrefix(redirAddr, "[")
host = strings.TrimPrefix(host, "[")
} else {
// ipv6 without port
host = redirAddr
} }
} }
} else {
// domain or ipv4 without port
host = redirAddr
} }
redirHost, err := net.ResolveIPAddr("ip", host) sta.RedirAddr, err = net.ResolveIPAddr("ip", redirAddr)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unable to resolve RedirAddr: %v. ", err) return fmt.Errorf("unable to resolve RedirAddr: %v. ", err)
} }
return redirHost, port, nil
}
func parseProxyBook(bookEntries map[string][]string) (map[string]net.Addr, error) { for _, addr := range preParse.BindAddr {
proxyBook := map[string]net.Addr{} bindAddr, err := net.ResolveTCPAddr("tcp", addr)
for name, pair := range bookEntries { if err != nil {
return err
}
sta.BindAddr = append(sta.BindAddr, bindAddr)
}
for name, pair := range preParse.ProxyBook {
name = strings.ToLower(name) name = strings.ToLower(name)
if len(pair) != 2 { if len(pair) != 2 {
return nil, fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair) return fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair)
} }
network := strings.ToLower(pair[0]) network := strings.ToLower(pair[0])
switch network { switch network {
case "tcp": case "tcp":
addr, err := net.ResolveTCPAddr("tcp", pair[1]) addr, err := net.ResolveTCPAddr("tcp", pair[1])
if err != nil { if err != nil {
return nil, err return err
} }
proxyBook[name] = addr sta.ProxyBook[name] = addr
continue continue
case "udp": case "udp":
addr, err := net.ResolveUDPAddr("udp", pair[1]) addr, err := net.ResolveUDPAddr("udp", pair[1])
if err != nil { if err != nil {
return nil, err return err
} }
proxyBook[name] = addr sta.ProxyBook[name] = addr
continue continue
} }
} }
return proxyBook, nil
}
// ParseConfig reads the config file or semicolon-separated options and parse them into a RawConfig pvBytes, err := base64.StdEncoding.DecodeString(preParse.PrivateKey)
func ParseConfig(conf string) (raw RawConfig, err error) {
content, errPath := ioutil.ReadFile(conf)
if errPath != nil {
errJson := json.Unmarshal(content, &raw)
if errJson != nil {
err = fmt.Errorf("failed to read/unmarshal configuration, path is invalid or %v", errJson)
return
}
} else {
errJson := json.Unmarshal(content, &raw)
if errJson != nil {
err = fmt.Errorf("failed to read configuration file: %v", errJson)
return
}
}
if raw.ProxyBook == nil {
raw.ProxyBook = make(map[string][]string)
}
return
}
// InitState process the RawConfig and initialises a server State accordingly
func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, err error) {
sta = &State{
BypassUID: make(map[[16]byte]struct{}),
ProxyBook: map[string]net.Addr{},
UsedRandom: map[[32]byte]int64{},
RedirDialer: &net.Dialer{},
WorldState: worldState,
}
if preParse.CncMode {
err = errors.New("command & control mode not implemented")
return
} else {
var manager usermanager.UserManager
if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" {
manager = &usermanager.Voidmanager{}
} else {
manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
if err != nil { if err != nil {
return sta, err return errors.New("Failed to decode private key: " + err.Error())
}
}
sta.Panel = MakeUserPanel(manager)
}
if preParse.KeepAlive <= 0 {
sta.ProxyDialer = &net.Dialer{KeepAlive: -1}
} else {
sta.ProxyDialer = &net.Dialer{KeepAlive: time.Duration(preParse.KeepAlive) * time.Second}
}
sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr)
if err != nil {
err = fmt.Errorf("unable to parse RedirAddr: %v", err)
return
}
sta.ProxyBook, err = parseProxyBook(preParse.ProxyBook)
if err != nil {
err = fmt.Errorf("unable to parse ProxyBook: %v", err)
return
}
if len(preParse.PrivateKey) == 0 {
err = fmt.Errorf("must have a valid private key. Run `ck-server -key` to generate one")
return
} }
var pv [32]byte var pv [32]byte
copy(pv[:], preParse.PrivateKey) copy(pv[:], pvBytes)
sta.StaticPv = &pv sta.staticPv = &pv
sta.AdminUID = preParse.AdminUID adminUID, err := base64.StdEncoding.DecodeString(preParse.AdminUID)
if err != nil {
return errors.New("Failed to decode AdminUID: " + err.Error())
}
sta.AdminUID = adminUID
var arrUID [16]byte var arrUID [16]byte
for _, UID := range preParse.BypassUID { for _, UID := range preParse.BypassUID {
copy(arrUID[:], UID) copy(arrUID[:], UID)
sta.BypassUID[arrUID] = struct{}{} sta.BypassUID[arrUID] = struct{}{}
} }
if len(sta.AdminUID) != 0 { copy(arrUID[:], adminUID)
copy(arrUID[:], sta.AdminUID)
sta.BypassUID[arrUID] = struct{}{} sta.BypassUID[arrUID] = struct{}{}
} return nil
go sta.UsedRandomCleaner()
return sta, nil
} }
// IsBypass checks if a UID is a bypass user // IsBypass checks if a UID is a bypass user
@ -206,28 +180,31 @@ func (sta *State) IsBypass(UID []byte) bool {
return exist return exist
} }
const timestampTolerance = 180 * time.Second const TIMESTAMP_TOLERANCE = 180 * time.Second
const replayCacheAgeLimit = 12 * time.Hour const CACHE_CLEAN_INTERVAL = 12 * time.Hour
// UsedRandomCleaner clears the cache of used random fields every replayCacheAgeLimit // UsedRandomCleaner clears the cache of used random fields every CACHE_CLEAN_INTERVAL
func (sta *State) UsedRandomCleaner() { func (sta *State) UsedRandomCleaner() {
for { for {
time.Sleep(replayCacheAgeLimit) time.Sleep(CACHE_CLEAN_INTERVAL)
now := sta.Now()
sta.usedRandomM.Lock() sta.usedRandomM.Lock()
for key, t := range sta.UsedRandom { for key, t := range sta.usedRandom {
if time.Unix(t, 0).Before(sta.WorldState.Now().Add(timestampTolerance)) { if time.Unix(t, 0).Before(now.Add(TIMESTAMP_TOLERANCE)) {
delete(sta.UsedRandom, key) delete(sta.usedRandom, key)
} }
} }
sta.usedRandomM.Unlock() sta.usedRandomM.Unlock()
} }
} }
func (sta *State) registerRandom(r [32]byte) bool { func (sta *State) registerRandom(r []byte) bool {
var random [32]byte
copy(random[:], r)
sta.usedRandomM.Lock() sta.usedRandomM.Lock()
_, used := sta.UsedRandom[r] _, used := sta.usedRandom[random]
sta.UsedRandom[r] = sta.WorldState.Now().Unix() sta.usedRandom[random] = sta.Now().Unix()
sta.usedRandomM.Unlock() sta.usedRandomM.Unlock()
return used return used
} }

View File

@ -1,126 +0,0 @@
package server
import (
"net"
"testing"
)
func TestParseRedirAddr(t *testing.T) {
t.Run("ipv4 without port", func(t *testing.T) {
ipv4noPort := "1.2.3.4"
host, port, err := parseRedirAddr(ipv4noPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv4noPort, err)
return
}
if host.String() != "1.2.3.4" {
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("ipv4 with port", func(t *testing.T) {
ipv4wPort := "1.2.3.4:1234"
host, port, err := parseRedirAddr(ipv4wPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv4wPort, err)
return
}
if host.String() != "1.2.3.4" {
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
}
if port != "1234" {
t.Errorf("wrong port: expected %v, got %v", "1234", port)
}
})
t.Run("domain without port", func(t *testing.T) {
domainNoPort := "example.com"
host, port, err := parseRedirAddr(domainNoPort)
if err != nil {
t.Errorf("parsing %v error: %v", domainNoPort, err)
return
}
expIPs, err := net.LookupIP("example.com")
if err != nil {
t.Errorf("tester error: cannot resolve example.com: %v", err)
return
}
contain := false
for _, expIP := range expIPs {
if expIP.String() == host.String() {
contain = true
}
}
if !contain {
t.Errorf("expected one of %v got %v", expIPs, host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("domain with port", func(t *testing.T) {
domainWPort := "example.com:80"
host, port, err := parseRedirAddr(domainWPort)
if err != nil {
t.Errorf("parsing %v error: %v", domainWPort, err)
return
}
expIPs, err := net.LookupIP("example.com")
if err != nil {
t.Errorf("tester error: cannot resolve example.com: %v", err)
return
}
contain := false
for _, expIP := range expIPs {
if expIP.String() == host.String() {
contain = true
}
}
if !contain {
t.Errorf("expected one of %v got %v", expIPs, host.String())
}
if port != "80" {
t.Errorf("wrong port: expected %v, got %v", "80", port)
}
})
t.Run("ipv6 without port", func(t *testing.T) {
ipv6noPort := "a:b:c:d::"
host, port, err := parseRedirAddr(ipv6noPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv6noPort, err)
return
}
if host.String() != "a:b:c:d::" {
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("ipv6 with port", func(t *testing.T) {
ipv6wPort := "[a:b:c:d::]:80"
host, port, err := parseRedirAddr(ipv6wPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv6wPort, err)
return
}
if host.String() != "a:b:c:d::" {
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
}
if port != "80" {
t.Errorf("wrong port: expected %v, got %v", "80", port)
}
})
}

View File

@ -1,16 +1,23 @@
package server package server
import ( import (
"crypto" "github.com/cbeuw/Cloak/internal/util"
"errors"
"io"
"net" "net"
) )
type Responder = func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error)
type Transport interface { type Transport interface {
processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error) HasRecordLayer() bool
UnitReadFunc() func(net.Conn, []byte) (int, error)
} }
var ErrInvalidPubKey = errors.New("public key has invalid format") type TLS struct{}
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
func (TLS) String() string { return "TLS" }
func (TLS) HasRecordLayer() bool { return true }
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
type WebSocket struct{}
func (WebSocket) String() string { return "WebSocket" }
func (WebSocket) HasRecordLayer() bool { return false }
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }

View File

@ -1,129 +0,0 @@
package usermanager
import (
"bytes"
"encoding/base64"
"encoding/json"
"net/http"
gmux "github.com/gorilla/mux"
)
type APIRouter struct {
*gmux.Router
manager UserManager
}
func APIRouterOf(manager UserManager) *APIRouter {
ret := &APIRouter{
manager: manager,
}
ret.registerMux()
return ret
}
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
next.ServeHTTP(w, r)
})
}
func (ar *APIRouter) registerMux() {
ar.Router = gmux.NewRouter()
ar.HandleFunc("/admin/users", ar.listAllUsersHlr).Methods("GET")
ar.HandleFunc("/admin/users/{UID}", ar.getUserInfoHlr).Methods("GET")
ar.HandleFunc("/admin/users/{UID}", ar.writeUserInfoHlr).Methods("POST")
ar.HandleFunc("/admin/users/{UID}", ar.deleteUserHlr).Methods("DELETE")
ar.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
})
ar.Use(corsMiddleware)
}
func (ar *APIRouter) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
infos, err := ar.manager.ListAllUsers()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
resp, err := json.Marshal(infos)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (ar *APIRouter) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
uinfo, err := ar.manager.GetUserInfo(UID)
if err == ErrUserNotFound {
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
return
}
resp, err := json.Marshal(uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (ar *APIRouter) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var uinfo UserInfo
err = json.NewDecoder(r.Body).Decode(&uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !bytes.Equal(UID, uinfo.UID) {
http.Error(w, "UID mismatch", http.StatusBadRequest)
}
err = ar.manager.WriteUserInfo(uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusCreated)
}
func (ar *APIRouter) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = ar.manager.DeleteUser(UID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusOK)
}

View File

@ -1,214 +0,0 @@
package usermanager
import (
"bytes"
"encoding/base64"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
var mockUIDb64 = base64.URLEncoding.EncodeToString(mockUID)
func makeRouter(t *testing.T) (router *APIRouter, cleaner func()) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
cleaner = func() { os.Remove(tmpDB.Name()) }
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
router = APIRouterOf(mgr)
return router, cleaner
}
func TestWriteUserInfoHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
marshalled, err := json.Marshal(mockUserInfo)
if err != nil {
t.Fatal(err)
}
t.Run("ok", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
})
t.Run("partial update", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
assert.NoError(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
partialUserInfo := UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
}
partialMarshalled, _ := json.Marshal(partialUserInfo)
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
var got UserInfo
err = json.Unmarshal(rr.Body.Bytes(), &got)
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = partialUserInfo.SessionsCap
assert.EqualValues(t, expected, got)
})
t.Run("empty parameter", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusMethodNotAllowed, rr.Code, "response body: %v", rr.Body)
})
t.Run("UID mismatch", func(t *testing.T) {
badMock := mockUserInfo
badMock.UID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0}
badMarshal, err := json.Marshal(badMock)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(badMarshal))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
t.Run("garbage data", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer([]byte(`{"{{'{;;}}}1`)))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
t.Run("not base64", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+"defonotbase64", bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
}
func addUser(t *testing.T, router *APIRouter, user UserInfo) {
marshalled, err := json.Marshal(user)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("POST", "/admin/users/"+base64.URLEncoding.EncodeToString(user.UID), bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
}
func TestGetUserInfoHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
t.Run("empty parameter", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/", nil)
})
t.Run("non-existent", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
})
t.Run("not base64", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+"defonotbase64", nil)
})
t.Run("ok", func(t *testing.T) {
addUser(t, router, mockUserInfo)
var got UserInfo
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)), &got)
if err != nil {
t.Fatal(err)
}
assert.EqualValues(t, mockUserInfo, got)
})
}
func TestDeleteUserHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
t.Run("non-existent", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
})
t.Run("not base64", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+"defonotbase64", nil)
})
t.Run("ok", func(t *testing.T) {
addUser(t, router, mockUserInfo)
assert.HTTPSuccess(t, router.ServeHTTP, "DELETE", "/admin/users/"+mockUIDb64, nil)
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)
})
}
func TestListAllUsersHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
user1 := mockUserInfo
addUser(t, router, user1)
user2 := mockUserInfo
user2.UID = []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
addUser(t, router, user2)
expected := []UserInfo{user1, user2}
var got []UserInfo
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users", nil)), &got)
if err != nil {
t.Fatal(err)
}
assert.True(t, assert.Subset(t, got, expected), assert.Subset(t, expected, got))
}

View File

@ -2,44 +2,68 @@ package usermanager
import ( import (
"encoding/binary" "encoding/binary"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt" "net/http"
"time"
"github.com/boltdb/bolt"
gmux "github.com/gorilla/mux"
) )
var u32 = binary.BigEndian.Uint32 var Uint32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64 var Uint64 = binary.BigEndian.Uint64
var PutUint32 = binary.BigEndian.PutUint32
var PutUint64 = binary.BigEndian.PutUint64
func i64ToB(value int64) []byte { func i64ToB(value int64) []byte {
oct := make([]byte, 8) oct := make([]byte, 8)
binary.BigEndian.PutUint64(oct, uint64(value)) PutUint64(oct, uint64(value))
return oct return oct
} }
func i32ToB(value int32) []byte { func i32ToB(value int32) []byte {
nib := make([]byte, 4) nib := make([]byte, 4)
binary.BigEndian.PutUint32(nib, uint32(value)) PutUint32(nib, uint32(value))
return nib return nib
} }
// localManager is responsible for managing the local user database // localManager is responsible for routing API calls to appropriate handlers and manage the local user database accordingly
type localManager struct { type localManager struct {
db *bolt.DB db *bolt.DB
world common.WorldState Router *gmux.Router
} }
func MakeLocalManager(dbPath string, worldState common.WorldState) (*localManager, error) { func MakeLocalManager(dbPath string) (*localManager, error) {
db, err := bolt.Open(dbPath, 0600, nil) db, err := bolt.Open(dbPath, 0600, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret := &localManager{ ret := &localManager{
db: db, db: db,
world: worldState,
} }
ret.Router = ret.registerMux()
return ret, nil return ret, nil
} }
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
next.ServeHTTP(w, r)
})
}
func (manager *localManager) registerMux() *gmux.Router {
r := gmux.NewRouter()
r.HandleFunc("/admin/users", manager.listAllUsersHlr).Methods("GET")
r.HandleFunc("/admin/users/{UID}", manager.getUserInfoHlr).Methods("GET")
r.HandleFunc("/admin/users/{UID}", manager.writeUserInfoHlr).Methods("POST")
r.HandleFunc("/admin/users/{UID}", manager.deleteUserHlr).Methods("DELETE")
r.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
})
r.Use(corsMiddleware)
return r
}
// Authenticate user returns err==nil along with the users' up and down bandwidths if the UID is allowed to connect // Authenticate user returns err==nil along with the users' up and down bandwidths if the UID is allowed to connect
// More specifically it checks that the user exists, that it has positive credit and that it hasn't expired // More specifically it checks that the user exists, that it has positive credit and that it hasn't expired
func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) { func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) {
@ -49,11 +73,11 @@ func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error)
if bucket == nil { if bucket == nil {
return ErrUserNotFound return ErrUserNotFound
} }
upRate = int64(u64(bucket.Get([]byte("UpRate")))) upRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
downRate = int64(u64(bucket.Get([]byte("DownRate")))) downRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
upCredit = int64(u64(bucket.Get([]byte("UpCredit")))) upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(u64(bucket.Get([]byte("DownCredit")))) downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
return nil return nil
}) })
if err != nil { if err != nil {
@ -65,7 +89,7 @@ func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error)
if downCredit <= 0 { if downCredit <= 0 {
return 0, 0, ErrNoDownCredit return 0, 0, ErrNoDownCredit
} }
if expiryTime < manager.world.Now().Unix() { if expiryTime < time.Now().Unix() {
return 0, 0, ErrUserExpired return 0, 0, ErrUserExpired
} }
@ -84,10 +108,10 @@ func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo Authorisation
if bucket == nil { if bucket == nil {
return ErrUserNotFound return ErrUserNotFound
} }
sessionsCap = int(u32(bucket.Get([]byte("SessionsCap")))) sessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
upCredit = int64(u64(bucket.Get([]byte("UpCredit")))) upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(u64(bucket.Get([]byte("DownCredit")))) downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
return nil return nil
}) })
if err != nil { if err != nil {
@ -99,7 +123,7 @@ func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo Authorisation
if downCredit <= 0 { if downCredit <= 0 {
return ErrNoDownCredit return ErrNoDownCredit
} }
if expiryTime < manager.world.Now().Unix() { if expiryTime < time.Now().Unix() {
return ErrUserExpired return ErrUserExpired
} }
@ -114,9 +138,6 @@ func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo Authorisation
// If no action is needed, there won't be a StatusResponse entry for that user // If no action is needed, there won't be a StatusResponse entry for that user
func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusResponse, error) { func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusResponse, error) {
var responses []StatusResponse var responses []StatusResponse
if len(uploads) == 0 {
return responses, nil
}
err := manager.db.Update(func(tx *bolt.Tx) error { err := manager.db.Update(func(tx *bolt.Tx) error {
for _, status := range uploads { for _, status := range uploads {
var resp StatusResponse var resp StatusResponse
@ -131,7 +152,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
continue continue
} }
oldUp := int64(u64(bucket.Get([]byte("UpCredit")))) oldUp := int64(Uint64(bucket.Get([]byte("UpCredit"))))
newUp := oldUp - status.UpUsage newUp := oldUp - status.UpUsage
if newUp <= 0 { if newUp <= 0 {
resp = StatusResponse{ resp = StatusResponse{
@ -140,13 +161,15 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
"No upload credit left", "No upload credit left",
} }
responses = append(responses, resp) responses = append(responses, resp)
continue
} }
err := bucket.Put([]byte("UpCredit"), i64ToB(newUp)) err := bucket.Put([]byte("UpCredit"), i64ToB(newUp))
if err != nil { if err != nil {
log.Error(err) log.Error(err)
continue
} }
oldDown := int64(u64(bucket.Get([]byte("DownCredit")))) oldDown := int64(Uint64(bucket.Get([]byte("DownCredit"))))
newDown := oldDown - status.DownUsage newDown := oldDown - status.DownUsage
if newDown <= 0 { if newDown <= 0 {
resp = StatusResponse{ resp = StatusResponse{
@ -155,20 +178,23 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
"No download credit left", "No download credit left",
} }
responses = append(responses, resp) responses = append(responses, resp)
continue
} }
err = bucket.Put([]byte("DownCredit"), i64ToB(newDown)) err = bucket.Put([]byte("DownCredit"), i64ToB(newDown))
if err != nil { if err != nil {
log.Error(err) log.Error(err)
continue
} }
expiry := int64(u64(bucket.Get([]byte("ExpiryTime")))) expiry := int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
if manager.world.Now().Unix() > expiry { if time.Now().Unix() > expiry {
resp = StatusResponse{ resp = StatusResponse{
status.UID, status.UID,
TERMINATE, TERMINATE,
"User has expired", "User has expired",
} }
responses = append(responses, resp) responses = append(responses, resp)
continue
} }
} }
return nil return nil
@ -176,94 +202,6 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
return responses, err return responses, err
} }
func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
err = manager.db.View(func(tx *bolt.Tx) error {
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
infos = append(infos, uinfo)
return nil
})
return err
})
if infos == nil {
infos = []UserInfo{}
}
return
}
func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) {
err = manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(UID)
if bucket == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
return nil
})
return
}
func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(u.UID)
if err != nil {
return err
}
if u.SessionsCap != nil {
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
return err
}
}
if u.UpRate != nil {
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
return err
}
}
if u.DownRate != nil {
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
return err
}
}
if u.UpCredit != nil {
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
return err
}
}
if u.DownCredit != nil {
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
return err
}
}
if u.ExpiryTime != nil {
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
return err
}
}
return nil
})
return
}
func (manager *localManager) DeleteUser(UID []byte) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error {
return tx.DeleteBucket(UID)
})
return
}
func (manager *localManager) Close() error { func (manager *localManager) Close() error {
return manager.db.Close() return manager.db.Close()
} }

View File

@ -0,0 +1,165 @@
package usermanager
import (
"bytes"
"encoding/base64"
"encoding/json"
"github.com/boltdb/bolt"
"net/http"
gmux "github.com/gorilla/mux"
)
type UserInfo struct {
UID []byte
SessionsCap int
UpRate int64
DownRate int64
UpCredit int64
DownCredit int64
ExpiryTime int64
}
func (manager *localManager) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
var infos []UserInfo
_ = manager.db.View(func(tx *bolt.Tx) error {
err := tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
infos = append(infos, uinfo)
return nil
})
return err
})
resp, err := json.Marshal(infos)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (manager *localManager) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var uinfo UserInfo
err = manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(UID))
if bucket == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
uinfo.UpRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
return nil
})
if err == ErrUserNotFound {
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
return
}
resp, err := json.Marshal(uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (manager *localManager) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
jsonUinfo := r.FormValue("UserInfo")
if jsonUinfo == "" {
http.Error(w, "UserInfo cannot be empty", http.StatusBadRequest)
return
}
var uinfo UserInfo
err = json.Unmarshal([]byte(jsonUinfo), &uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !bytes.Equal(UID, uinfo.UID) {
http.Error(w, "UID mismatch", http.StatusBadRequest)
}
err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID)
if err != nil {
return err
}
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil {
return err
}
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
return err
}
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
return err
}
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
return err
}
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
return err
}
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
return err
}
return nil
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusCreated)
}
func (manager *localManager) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = manager.db.Update(func(tx *bolt.Tx) error {
return tx.DeleteBucket(UID)
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusOK)
}

View File

@ -2,7 +2,7 @@ swagger: '2.0'
info: info:
description: | description: |
This is the API of Cloak server This is the API of Cloak server
version: 0.0.2 version: 1.0.0
title: Cloak Server title: Cloak Server
contact: contact:
email: cbeuw.andy@gmail.com email: cbeuw.andy@gmail.com
@ -12,6 +12,8 @@ info:
# host: petstore.swagger.io # host: petstore.swagger.io
# basePath: /v2 # basePath: /v2
tags: tags:
- name: admin
description: Endpoints used by the host administrators
- name: users - name: users
description: Operations related to user controls by admin description: Operations related to user controls by admin
# schemes: # schemes:
@ -20,6 +22,7 @@ paths:
/admin/users: /admin/users:
get: get:
tags: tags:
- admin
- users - users
summary: Show all users summary: Show all users
description: Returns an array of all UserInfo description: Returns an array of all UserInfo
@ -38,6 +41,7 @@ paths:
/admin/users/{UID}: /admin/users/{UID}:
get: get:
tags: tags:
- admin
- users - users
summary: Show userinfo by UID summary: Show userinfo by UID
description: Returns a UserInfo object description: Returns a UserInfo object
@ -64,6 +68,7 @@ paths:
description: internal error description: internal error
post: post:
tags: tags:
- admin
- users - users
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
operationId: writeUserInfo operationId: writeUserInfo
@ -95,6 +100,7 @@ paths:
description: internal error description: internal error
delete: delete:
tags: tags:
- admin
- users - users
summary: Deletes a user summary: Deletes a user
operationId: deleteUser operationId: deleteUser

View File

@ -1,365 +0,0 @@
package usermanager
import (
"encoding/binary"
"io/ioutil"
"math/rand"
"os"
"reflect"
"sort"
"sync"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
)
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var mockUserInfo = UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
UpRate: JustInt64(100),
DownRate: JustInt64(1000),
UpCredit: JustInt64(10000),
DownCredit: JustInt64(100000),
ExpiryTime: JustInt64(1000000),
}
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
cleaner = func() { os.Remove(tmpDB.Name()) }
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
return mgr, cleaner
}
func TestLocalManager_WriteUserInfo(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
err := mgr.WriteUserInfo(mockUserInfo)
if err != nil {
t.Error(err)
}
got, err := mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, mockUserInfo, got)
/* Partial update */
err = mgr.WriteUserInfo(UserInfo{
UID: mockUID,
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
})
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
got, err = mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, expected, got)
}
func TestLocalManager_GetUserInfo(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("simple fetch", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo)
gotInfo, err := mgr.GetUserInfo(mockUID)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotInfo, mockUserInfo) {
t.Errorf("got wrong user info: %v", gotInfo)
}
})
t.Run("update a field", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo)
updatedUserInfo := mockUserInfo
updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
err := mgr.WriteUserInfo(updatedUserInfo)
if err != nil {
t.Error(err)
}
gotInfo, err := mgr.GetUserInfo(mockUID)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotInfo, updatedUserInfo) {
t.Errorf("got wrong user info: %v", updatedUserInfo)
}
})
t.Run("non existent user", func(t *testing.T) {
_, err := mgr.GetUserInfo(make([]byte, 16))
if err != ErrUserNotFound {
t.Errorf("expecting error %v, got %v", ErrUserNotFound, err)
}
})
}
func TestLocalManager_DeleteUser(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
_ = mgr.WriteUserInfo(mockUserInfo)
err := mgr.DeleteUser(mockUID)
if err != nil {
t.Error(err)
}
_, err = mgr.GetUserInfo(mockUID)
if err != ErrUserNotFound {
t.Error("user not deleted")
}
}
var validUserInfo = mockUserInfo
func TestLocalManager_AuthenticateUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
t.Run("normal auth", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
upRate, downRate, err := mgr.AuthenticateUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
t.Error("wrong up or down rate")
}
})
t.Run("non existent user", func(t *testing.T) {
_, _, err := mgr.AuthenticateUser(make([]byte, 16))
if err != ErrUserNotFound {
t.Error("user found")
}
})
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
_, _, err := mgr.AuthenticateUser(expiredUserInfo.UID)
if err != ErrUserExpired {
t.Error("user not expired")
}
})
t.Run("no credit", func(t *testing.T) {
creditlessUserInfo := validUserInfo
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
_ = mgr.WriteUserInfo(creditlessUserInfo)
_, _, err := mgr.AuthenticateUser(creditlessUserInfo.UID)
if err != ErrNoUpCredit && err != ErrNoDownCredit {
t.Error("user not creditless")
}
})
}
func TestLocalManager_AuthoriseNewSession(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("normal auth", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
if err != nil {
t.Error(err)
}
})
t.Run("non existent user", func(t *testing.T) {
err := mgr.AuthoriseNewSession(make([]byte, 16), AuthorisationInfo{NumExistingSessions: 0})
if err != ErrUserNotFound {
t.Error("user found")
}
})
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
if err != ErrUserExpired {
t.Error("user not expired")
}
})
t.Run("too many sessions", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
if err != ErrSessionsCapReached {
t.Error("session cap not reached")
}
})
}
func TestLocalManager_UploadStatus(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("simple update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
update := StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 10,
DownUsage: 100,
Timestamp: mockWorldState.Now().Unix(),
}
_, err := mgr.UploadStatus([]StatusUpdate{update})
if err != nil {
t.Error(err)
}
updatedUserInfo, err := mgr.GetUserInfo(validUserInfo.UID)
if err != nil {
t.Error(err)
}
if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
t.Error("up usage incorrect")
}
if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
t.Error("down usage incorrect")
}
})
badUpdates := []struct {
name string
user UserInfo
update StatusUpdate
}{
{"out of up credit",
validUserInfo,
StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: *validUserInfo.UpCredit + 100,
DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(),
},
},
{"out of down credit",
validUserInfo,
StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 0,
DownUsage: *validUserInfo.DownCredit + 100,
Timestamp: mockWorldState.Now().Unix(),
},
},
{"expired",
UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
UpRate: JustInt64(0),
DownRate: JustInt64(0),
UpCredit: JustInt64(0),
DownCredit: JustInt64(0),
ExpiryTime: JustInt64(-1),
},
StatusUpdate{
UID: mockUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 0,
DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(),
},
},
}
for _, badUpdate := range badUpdates {
t.Run(badUpdate.name, func(t *testing.T) {
_ = mgr.WriteUserInfo(badUpdate.user)
resps, err := mgr.UploadStatus([]StatusUpdate{badUpdate.update})
if err != nil {
t.Error(err)
}
if len(resps) == 0 {
t.Fatal("expecting responses")
}
resp := resps[0]
if resp.Action != TERMINATE {
t.Errorf("didn't terminate when %v", badUpdate.name)
}
})
}
}
func TestLocalManager_ListAllUsers(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
var wg sync.WaitGroup
var users []UserInfo
for i := 0; i < 100; i++ {
randUID := make([]byte, 16)
rand.Read(randUID)
newUser := UserInfo{
UID: randUID,
SessionsCap: JustInt32(rand.Int31()),
UpRate: JustInt64(rand.Int63()),
DownRate: JustInt64(rand.Int63()),
UpCredit: JustInt64(rand.Int63()),
DownCredit: JustInt64(rand.Int63()),
ExpiryTime: JustInt64(rand.Int63()),
}
users = append(users, newUser)
wg.Add(1)
go func() {
err := mgr.WriteUserInfo(newUser)
if err != nil {
t.Fatal(err)
}
wg.Done()
}()
}
wg.Wait()
listedUsers, err := mgr.ListAllUsers()
if err != nil {
t.Error(err)
}
sort.Slice(users, func(i, j int) bool {
return binary.BigEndian.Uint64(users[i].UID[0:8]) < binary.BigEndian.Uint64(users[j].UID[0:8])
})
sort.Slice(listedUsers, func(i, j int) bool {
return binary.BigEndian.Uint64(listedUsers[i].UID[0:8]) < binary.BigEndian.Uint64(listedUsers[j].UID[0:8])
})
if !reflect.DeepEqual(users, listedUsers) {
t.Error("listed users deviates from uploaded ones")
}
}

View File

@ -14,23 +14,6 @@ type StatusUpdate struct {
Timestamp int64 Timestamp int64
} }
type MaybeInt32 *int32
type MaybeInt64 *int64
type UserInfo struct {
UID []byte
SessionsCap MaybeInt32
UpRate MaybeInt64
DownRate MaybeInt64
UpCredit MaybeInt64
DownCredit MaybeInt64
ExpiryTime MaybeInt64
}
func JustInt32(v int32) MaybeInt32 { return &v }
func JustInt64(v int64) MaybeInt64 { return &v }
type StatusResponse struct { type StatusResponse struct {
UID []byte UID []byte
Action int Action int
@ -47,7 +30,6 @@ const (
var ErrUserNotFound = errors.New("UID does not correspond to a user") var ErrUserNotFound = errors.New("UID does not correspond to a user")
var ErrSessionsCapReached = errors.New("Sessions cap has reached") var ErrSessionsCapReached = errors.New("Sessions cap has reached")
var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified")
var ErrNoUpCredit = errors.New("No upload credit left") var ErrNoUpCredit = errors.New("No upload credit left")
var ErrNoDownCredit = errors.New("No download credit left") var ErrNoDownCredit = errors.New("No download credit left")
@ -57,8 +39,4 @@ type UserManager interface {
AuthenticateUser([]byte) (int64, int64, error) AuthenticateUser([]byte) (int64, int64, error)
AuthoriseNewSession([]byte, AuthorisationInfo) error AuthoriseNewSession([]byte, AuthorisationInfo) error
UploadStatus([]StatusUpdate) ([]StatusResponse, error) UploadStatus([]StatusUpdate) ([]StatusResponse, error)
ListAllUsers() ([]UserInfo, error)
GetUserInfo(UID []byte) (UserInfo, error)
WriteUserInfo(UserInfo) error
DeleteUser(UID []byte) error
} }

View File

@ -1,31 +0,0 @@
package usermanager
type Voidmanager struct{}
func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) {
return 0, 0, ErrMangerIsVoid
}
func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) {
return nil, ErrMangerIsVoid
}
func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) {
return []UserInfo{}, ErrMangerIsVoid
}
func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) {
return UserInfo{}, ErrMangerIsVoid
}
func (v *Voidmanager) WriteUserInfo(info UserInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) DeleteUser(UID []byte) error {
return ErrMangerIsVoid
}

View File

@ -1,44 +0,0 @@
package usermanager
import (
"testing"
"github.com/stretchr/testify/assert"
)
var v = &Voidmanager{}
func Test_Voidmanager_AuthenticateUser(t *testing.T) {
_, _, err := v.AuthenticateUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_AuthoriseNewSession(t *testing.T) {
err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_DeleteUser(t *testing.T) {
err := v.DeleteUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_GetUserInfo(t *testing.T) {
_, err := v.GetUserInfo([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_ListAllUsers(t *testing.T) {
_, err := v.ListAllUsers()
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_UploadStatus(t *testing.T) {
_, err := v.UploadStatus([]StatusUpdate{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_WriteUserInfo(t *testing.T) {
err := v.WriteUserInfo(UserInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}

View File

@ -1,20 +1,15 @@
package server package server
import ( import (
"encoding/base64" "github.com/cbeuw/Cloak/internal/server/usermanager"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/cbeuw/Cloak/internal/server/usermanager"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const defaultUploadInterval = 1 * time.Minute
// userPanel is used to authenticate new users and book keep active users
type userPanel struct { type userPanel struct {
Manager usermanager.UserManager Manager usermanager.UserManager
@ -22,8 +17,6 @@ type userPanel struct {
activeUsers map[[16]byte]*ActiveUser activeUsers map[[16]byte]*ActiveUser
usageUpdateQueueM sync.Mutex usageUpdateQueueM sync.Mutex
usageUpdateQueue map[[16]byte]*usagePair usageUpdateQueue map[[16]byte]*usagePair
uploadInterval time.Duration
} }
func MakeUserPanel(manager usermanager.UserManager) *userPanel { func MakeUserPanel(manager usermanager.UserManager) *userPanel {
@ -31,7 +24,6 @@ func MakeUserPanel(manager usermanager.UserManager) *userPanel {
Manager: manager, Manager: manager,
activeUsers: make(map[[16]byte]*ActiveUser), activeUsers: make(map[[16]byte]*ActiveUser),
usageUpdateQueue: make(map[[16]byte]*usagePair), usageUpdateQueue: make(map[[16]byte]*usagePair),
uploadInterval: defaultUploadInterval,
} }
go ret.regularQueueUpload() go ret.regularQueueUpload()
return ret return ret
@ -40,10 +32,10 @@ func MakeUserPanel(manager usermanager.UserManager) *userPanel {
// GetBypassUser does the same as GetUser except it unconditionally creates an ActiveUser when the UID isn't already active // GetBypassUser does the same as GetUser except it unconditionally creates an ActiveUser when the UID isn't already active
func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) { func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) {
panel.activeUsersM.Lock() panel.activeUsersM.Lock()
defer panel.activeUsersM.Unlock()
var arrUID [16]byte var arrUID [16]byte
copy(arrUID[:], UID) copy(arrUID[:], UID)
if user, ok := panel.activeUsers[arrUID]; ok { if user, ok := panel.activeUsers[arrUID]; ok {
panel.activeUsersM.Unlock()
return user, nil return user, nil
} }
user := &ActiveUser{ user := &ActiveUser{
@ -54,6 +46,7 @@ func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) {
} }
copy(user.arrUID[:], UID) copy(user.arrUID[:], UID)
panel.activeUsers[user.arrUID] = user panel.activeUsers[user.arrUID] = user
panel.activeUsersM.Unlock()
return user, nil return user, nil
} }
@ -61,15 +54,16 @@ func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) {
// UID with UserInfo queried from the UserManger, should the particular UID is allowed to connect // UID with UserInfo queried from the UserManger, should the particular UID is allowed to connect
func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) { func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
panel.activeUsersM.Lock() panel.activeUsersM.Lock()
defer panel.activeUsersM.Unlock()
var arrUID [16]byte var arrUID [16]byte
copy(arrUID[:], UID) copy(arrUID[:], UID)
if user, ok := panel.activeUsers[arrUID]; ok { if user, ok := panel.activeUsers[arrUID]; ok {
panel.activeUsersM.Unlock()
return user, nil return user, nil
} }
upRate, downRate, err := panel.Manager.AuthenticateUser(UID) upRate, downRate, err := panel.Manager.AuthenticateUser(UID)
if err != nil { if err != nil {
panel.activeUsersM.Unlock()
return nil, err return nil, err
} }
valve := mux.MakeValve(upRate, downRate) valve := mux.MakeValve(upRate, downRate)
@ -81,18 +75,12 @@ func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
copy(user.arrUID[:], UID) copy(user.arrUID[:], UID)
panel.activeUsers[user.arrUID] = user panel.activeUsers[user.arrUID] = user
log.WithFields(log.Fields{ panel.activeUsersM.Unlock()
"UID": base64.StdEncoding.EncodeToString(UID),
}).Info("New active user")
return user, nil return user, nil
} }
// TerminateActiveUser terminates a user and deletes its references // TerminateActiveUser terminates a user and deletes its references
func (panel *userPanel) TerminateActiveUser(user *ActiveUser, reason string) { func (panel *userPanel) TerminateActiveUser(user *ActiveUser, reason string) {
log.WithFields(log.Fields{
"UID": base64.StdEncoding.EncodeToString(user.arrUID[:]),
"reason": reason,
}).Info("Terminating active user")
panel.updateUsageQueueForOne(user) panel.updateUsageQueueForOne(user)
user.closeAllSessions(reason) user.closeAllSessions(reason)
panel.activeUsersM.Lock() panel.activeUsersM.Lock()
@ -168,9 +156,6 @@ func (panel *userPanel) commitUpdate() error {
panel.activeUsersM.RUnlock() panel.activeUsersM.RUnlock()
var numSession int var numSession int
if user != nil { if user != nil {
if user.bypass {
continue
}
numSession = user.NumSession() numSession = user.NumSession()
} }
status := usermanager.StatusUpdate{ status := usermanager.StatusUpdate{
@ -183,12 +168,7 @@ func (panel *userPanel) commitUpdate() error {
} }
statuses = append(statuses, status) statuses = append(statuses, status)
} }
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
panel.usageUpdateQueueM.Unlock()
if len(statuses) == 0 {
return nil
}
responses, err := panel.Manager.UploadStatus(statuses) responses, err := panel.Manager.UploadStatus(statuses)
if err != nil { if err != nil {
return err return err
@ -206,12 +186,14 @@ func (panel *userPanel) commitUpdate() error {
} }
} }
} }
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
panel.usageUpdateQueueM.Unlock()
return nil return nil
} }
func (panel *userPanel) regularQueueUpload() { func (panel *userPanel) regularQueueUpload() {
for { for {
time.Sleep(panel.uploadInterval) time.Sleep(1 * time.Minute)
go func() { go func() {
panel.updateUsageQueue() panel.updateUsageQueue()
err := panel.commitUpdate() err := panel.commitUpdate()

View File

@ -2,20 +2,15 @@ package server
import ( import (
"encoding/base64" "encoding/base64"
"io/ioutil" "github.com/cbeuw/Cloak/internal/server/usermanager"
"os" "os"
"testing" "testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/server/usermanager"
) )
func TestUserPanel_BypassUser(t *testing.T) { const MOCK_DB_NAME = "userpanel_test_mock_database.db"
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState) func TestUserPanel_BypassUser(t *testing.T) {
manager, err := usermanager.MakeLocalManager(MOCK_DB_NAME)
if err != nil { if err != nil {
t.Error("failed to make local manager", err) t.Error("failed to make local manager", err)
} }
@ -61,130 +56,8 @@ func TestUserPanel_BypassUser(t *testing.T) {
if err != nil { if err != nil {
t.Error("failed to close localmanager", err) t.Error("failed to close localmanager", err)
} }
} err = os.Remove(MOCK_DB_NAME)
if err != nil {
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} t.Error("failed to delete mockdb", err)
var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) }
var validUserInfo = usermanager.UserInfo{
UID: mockUID,
SessionsCap: usermanager.JustInt32(10),
UpRate: usermanager.JustInt64(100),
DownRate: usermanager.JustInt64(1000),
UpCredit: usermanager.JustInt64(10000),
DownCredit: usermanager.JustInt64(100000),
ExpiryTime: usermanager.JustInt64(1000000),
}
func TestUserPanel_GetUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
panel := MakeUserPanel(mgr)
t.Run("normal user", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
activeUser, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
again, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Errorf("can't get existing user: %v", err)
}
if activeUser != again {
t.Error("got different references")
}
})
t.Run("non existent user", func(t *testing.T) {
_, err = panel.GetUser(make([]byte, 16))
if err != usermanager.ErrUserNotFound {
t.Errorf("expecting error %v, got %v", usermanager.ErrUserNotFound, err)
}
})
}
func TestUserPanel_UpdateUsageQueue(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
panel := MakeUserPanel(mgr)
t.Run("normal update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
user, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
user.valve.AddTx(1)
user.valve.AddRx(2)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
if user.valve.GetRx() != 0 || user.valve.GetTx() != 0 {
t.Error("rx and tx stats are not cleared")
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
t.Error("down credit incorrect update")
}
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
t.Error("up credit incorrect update")
}
// another update
user.valve.AddTx(3)
user.valve.AddRx(4)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
t.Error("down credit incorrect update")
}
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
t.Error("up credit incorrect update")
}
})
t.Run("terminating update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
user, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
user.valve.AddTx(*validUserInfo.DownCredit + 100)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
if panel.isActive(validUserInfo.UID) {
t.Error("user not terminated")
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != -100 {
t.Error("down credit not updated correctly after the user has been terminated")
}
})
} }

View File

@ -1,103 +1,112 @@
package server package server
import ( import (
"bufio"
"bytes"
"crypto" "crypto"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io" "github.com/cbeuw/Cloak/internal/ecdh"
"github.com/cbeuw/Cloak/internal/util"
"github.com/gorilla/websocket"
"net" "net"
"net/http" "net/http"
"github.com/cbeuw/Cloak/internal/common" log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/ecdh"
) )
type WebSocket struct{} // since we need to read the first packet from the client to identify its protocol, the first packet will no longer
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
func (WebSocket) String() string { return "WebSocket" } // fake a conn that returns the first packet on first read
type firstBuffedConn struct {
func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { net.Conn
var req *http.Request firstRead bool
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket))) firstPacket []byte
if err != nil {
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
return
}
var hiddenData []byte
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
fragments, err = WebSocket{}.unmarshalHidden(hiddenData, privateKey)
if err != nil {
err = fmt.Errorf("failed to unmarshal hidden data from WS into authFragments: %v", err)
return
}
respond = WebSocket{}.makeResponder(reqPacket, fragments.sharedSecret)
return
} }
func (WebSocket) makeResponder(reqPacket []byte, sharedSecret [32]byte) Responder { func (c *firstBuffedConn) Read(buf []byte) (int, error) {
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) { if !c.firstRead {
handler := newWsHandshakeHandler() c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Conn.Read(buf)
}
// For an explanation of the following 3 lines, see the comments in websocketAux.go type wsAcceptor struct {
http.Serve(newWsAcceptor(originalConn, reqPacket), handler) done bool
c *firstBuffedConn
}
<-handler.finished // net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
preparedConn = handler.conn // http.Server. This is an acceptor that accepts only one Conn
nonce := make([]byte, 12) func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
common.RandRead(randSource, nonce) f := make([]byte, len(first))
copy(f, first)
return &wsAcceptor{
c: &firstBuffedConn{Conn: conn, firstPacket: f},
}
}
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag] func (w *wsAcceptor) Accept() (net.Conn, error) {
encryptedKey, err := common.AESGCMEncrypt(nonce, sharedSecret[:], sessionKey[:]) // 32 + 16 = 48 bytes if w.done {
return nil, errors.New("already accepted")
}
w.done = true
return w.c, nil
}
func (w *wsAcceptor) Close() error {
w.done = true
return nil
}
func (w *wsAcceptor) Addr() net.Addr {
return w.c.LocalAddr()
}
type wsHandshakeHandler struct {
conn net.Conn
finished chan struct{}
}
// the handler to turn a net.Conn into a websocket.Conn
func newWsHandshakeHandler() *wsHandshakeHandler {
return &wsHandshakeHandler{finished: make(chan struct{})}
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
err = fmt.Errorf("failed to encrypt reply: %v", err) log.Errorf("failed to upgrade connection to ws: %v", err)
return return
} }
reply := append(nonce, encryptedKey...) ws.conn = &util.WebSocketConn{Conn: c}
_, err = preparedConn.Write(reply) ws.finished <- struct{}{}
if err != nil {
err = fmt.Errorf("failed to write reply: %v", err)
preparedConn.Close()
return
}
return
}
return respond
} }
var ErrBadGET = errors.New("non (or malformed) HTTP GET") var ErrBadGET = errors.New("non (or malformed) HTTP GET")
func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fragments authFragments, err error) { func unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {
if len(hidden) < 96 { if len(hidden) < 96 {
err = ErrBadGET err = ErrBadGET
return return
} }
ephPub, ok := ecdh.Unmarshal(hidden[0:32])
copy(fragments.randPubKey[:], hidden[0:32])
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:])
if !ok { if !ok {
err = ErrInvalidPubKey err = ErrInvalidPubKey
return return
} }
var sharedSecret []byte ai.nonce = hidden[:12]
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil { ai.sharedSecret = ecdh.GenerateSharedSecret(staticPv, ephPub)
ai.ciphertextWithTag = hidden[32:]
if len(ai.ciphertextWithTag) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(ai.ciphertextWithTag))
return return
} }
copy(fragments.sharedSecret[:], sharedSecret)
if len(hidden[32:]) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))
return
}
copy(fragments.ciphertextWithTag[:], hidden[32:])
return return
} }

View File

@ -1,138 +0,0 @@
package server
import (
"errors"
"net"
"net/http"
"github.com/cbeuw/Cloak/internal/common"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
//
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
//
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
// net/http package upon receiving a request from a Conn.
//
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
// function.
//
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
// accepted.
//
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
// Accept will return error (so that the caller won't call again)
//
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
//
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
// websocket.Conn
//
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
//
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
// execution will block until the reference to util.WebSocketConn is ready.
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
// fake a conn that returns the first packet on first read
type firstBuffedConn struct {
net.Conn
firstRead bool
firstPacket []byte
}
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Conn.Read(buf)
}
type wsOnceListener struct {
done bool
c *firstBuffedConn
}
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
// http.Server. This is an acceptor that accepts only one Conn
func newWsAcceptor(conn net.Conn, first []byte) *wsOnceListener {
f := make([]byte, len(first))
copy(f, first)
return &wsOnceListener{
c: &firstBuffedConn{Conn: conn, firstPacket: f},
}
}
func (w *wsOnceListener) Accept() (net.Conn, error) {
if w.done {
return nil, errors.New("already accepted")
}
w.done = true
return w.c, nil
}
func (w *wsOnceListener) Close() error {
w.done = true
return nil
}
func (w *wsOnceListener) Addr() net.Addr {
return w.c.LocalAddr()
}
type wsHandshakeHandler struct {
conn net.Conn
finished chan struct{}
}
// the handler to turn a net.Conn into a websocket.Conn
func newWsHandshakeHandler() *wsHandshakeHandler {
return &wsHandshakeHandler{finished: make(chan struct{})}
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("failed to upgrade connection to ws: %v", err)
return
}
ws.conn = &common.WebSocketConn{Conn: c}
ws.finished <- struct{}{}
}

View File

@ -1,59 +0,0 @@
package server
import (
"bytes"
"testing"
"github.com/cbeuw/connutil"
)
func TestFirstBuffedConn_Read(t *testing.T) {
mockConn, writingEnd := connutil.AsyncPipe()
expectedFirstPacket := []byte{1, 2, 3}
firstBuffedConn := &firstBuffedConn{
Conn: mockConn,
firstRead: false,
firstPacket: expectedFirstPacket,
}
buf := make([]byte, 1024)
n, err := firstBuffedConn.Read(buf)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(expectedFirstPacket, buf[:n]) {
t.Error("first read doesn't produce given packet")
return
}
expectedSecondPacket := []byte{4, 5, 6, 7}
writingEnd.Write(expectedSecondPacket)
n, err = firstBuffedConn.Read(buf)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(expectedSecondPacket, buf[:n]) {
t.Error("second read doesn't produce subsequently written packet")
return
}
}
func TestWsAcceptor(t *testing.T) {
mockConn := connutil.Discard()
expectedFirstPacket := []byte{1, 2, 3}
wsAcceptor := newWsAcceptor(mockConn, expectedFirstPacket)
_, err := wsAcceptor.Accept()
if err != nil {
t.Error(err)
return
}
_, err = wsAcceptor.Accept()
if err == nil {
t.Error("accepting second time doesn't return error")
}
}

View File

@ -1,575 +0,0 @@
package test
import (
"bytes"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"math/rand"
"net"
"sync"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
log "github.com/sirupsen/logrus"
)
const numConns = 200 // -race option limits the number of goroutines to 8192
func serveTCPEcho(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
log.Error(err)
return
}
go func(conn net.Conn) {
_, err := io.Copy(conn, conn)
if err != nil {
conn.Close()
log.Error(err)
return
}
}(conn)
}
}
func serveUDPEcho(listener *connutil.PipeListener) {
for {
conn, err := listener.ListenPacket("udp", "")
if err != nil {
log.Error(err)
return
}
const bufSize = 32 * 1024
go func(conn net.PacketConn) {
defer conn.Close()
buf := make([]byte, bufSize)
for {
r, _, err := conn.ReadFrom(buf)
if err != nil {
log.Error(err)
return
}
w, err := conn.WriteTo(buf[:r], nil)
if err != nil {
log.Error(err)
return
}
if r != w {
log.Error("written not eqal to read")
return
}
}
}(conn)
}
}
var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=")
var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=")
var basicUDPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "openvpn",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 4,
UDP: true,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
}
var basicTCPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 4,
UDP: false,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
BrowserSig: "firefox",
}
var singleplexTCPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 0,
UDP: false,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
BrowserSig: "safari",
}
func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) {
lcl, rmt, auth, err := rawConfig.ProcessRawConfig(state)
if err != nil {
log.Fatal(err)
}
return lcl, rmt, auth
}
func basicServerState(ws common.WorldState) *server.State {
var serverConfig = server.RawConfig{
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
BindAddr: []string{"fake.com:9999"},
BypassUID: [][]byte{bypassUID[:]},
RedirAddr: "fake.com:9999",
PrivateKey: privateKey,
KeepAlive: 15,
CncMode: false,
}
state, err := server.InitState(serverConfig, ws)
if err != nil {
log.Fatal(err)
}
return state
}
type mockUDPDialer struct {
addrCh chan *net.UDPAddr
raddr *net.UDPAddr
}
func (m *mockUDPDialer) Dial(network, address string) (net.Conn, error) {
if m.raddr == nil {
m.raddr = <-m.addrCh
}
return net.DialUDP("udp", nil, m.raddr)
}
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) {
// redirecting web server
// ^
// |
// |
// redirFromCkServerL
// |
// |
// proxy client ----proxyToCkClientD----> ck-client ------> ck-server ----proxyFromCkServerL----> proxy server
// ^
// |
// |
// netToCkServerD
// |
// |
// whatever connection initiator (including a proper ck-client)
netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024)
clientSeshMaker := func() *mux.Session {
ai := ai
quad := make([]byte, 4)
common.RandRead(ai.WorldState.Rand, quad)
ai.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(rcc, ai, netToCkServerD)
}
var proxyToCkClientD common.Dialer
if ai.Unordered {
// We can only "dial" a single UDP connection as we can't send packets from different context
// to a single UDP listener
addrCh := make(chan *net.UDPAddr, 1)
mDialer := &mockUDPDialer{
addrCh: addrCh,
}
acceptor := func() (*net.UDPConn, error) {
laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, err := net.ListenUDP("udp", laddr)
addrCh <- conn.LocalAddr().(*net.UDPAddr)
return conn, err
}
go client.RouteUDP(acceptor, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
proxyToCkClientD = mDialer
} else {
var proxyToCkClientL *connutil.PipeListener
proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024)
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
}
// set up server
ckServerToProxyD, proxyFromCkServerL := connutil.DialerListener(10 * 1024)
ckServerToWebD, redirFromCkServerL := connutil.DialerListener(10 * 1024)
serverState.ProxyDialer = ckServerToProxyD
serverState.RedirDialer = ckServerToWebD
go server.Serve(ckServerListener, serverState)
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
}
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
testData := make([]byte, msgLen)
rand.Read(testData)
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData)
if n != msgLen {
t.Errorf("written only %v, err %v", n, err)
return
}
recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Errorf("failed to read back: %v", err)
return
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("echoed data not correct")
return
}
}(conn)
}
wg.Wait()
}
func TestUDP(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("simple send", func(t *testing.T) {
pxyClientConn, err := proxyToCkClientD.Dial("udp", "")
if err != nil {
t.Error(err)
}
const testDataLen = 1500
testData := make([]byte, testDataLen)
rand.Read(testData)
n, err := pxyClientConn.Write(testData)
if n != testDataLen {
t.Errorf("wrong length sent: %v", n)
}
if err != nil {
t.Error(err)
}
pxyServerConn, err := proxyFromCkServerL.ListenPacket("", "")
if err != nil {
t.Error(err)
}
recvBuf := make([]byte, testDataLen+100)
r, _, err := pxyServerConn.ReadFrom(recvBuf)
if err != nil {
t.Error(err)
}
if !bytes.Equal(testData, recvBuf[:r]) {
t.Error("read wrong data")
}
})
const echoMsgLen = 1024
t.Run("user echo", func(t *testing.T) {
go serveUDPEcho(proxyFromCkServerL)
var conn [1]net.Conn
conn[0], err = proxyToCkClientD.Dial("udp", "")
if err != nil {
t.Error(err)
}
runEchoTest(t, conn[:], echoMsgLen)
})
}
func TestTCPSingleplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
const echoMsgLen = 1 << 16
go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen)
user, err := sta.Panel.GetUser(ai.UID[:])
if err != nil {
t.Fatalf("failed to fetch user: %v", err)
}
if user.NumSession() != 1 {
t.Error("no session were made on first connection establishment")
}
proxyConn2, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
if user.NumSession() != 2 {
t.Error("no extra session were made on second connection establishment")
}
// Both conns should work
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen)
proxyConn1.Close()
assert.Eventually(t, func() bool {
return user.NumSession() == 1
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
// conn2 should still work
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
}
func TestTCPMultiplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("user echo single", func(t *testing.T) {
for i := 0; i < 18; i += 2 {
dataLen := 1 << i
writeData := make([]byte, dataLen)
rand.Read(writeData)
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
go serveTCPEcho(proxyFromCkServerL)
conn, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
n, err := conn.Write(writeData)
if err != nil {
t.Error(err)
}
if n != dataLen {
t.Errorf("write length doesn't match up: %v, expected %v", n, dataLen)
}
recvBuf := make([]byte, dataLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Error(err)
}
if !bytes.Equal(writeData, recvBuf) {
t.Error("echoed data incorrect")
}
})
}
})
const echoMsgLen = 16384
t.Run("user echo", func(t *testing.T) {
go serveTCPEcho(proxyFromCkServerL)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
})
t.Run("redir echo", func(t *testing.T) {
go serveTCPEcho(redirFromCkServerL)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = netToCkServerD.Dial("", "")
if err != nil {
t.Error(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
})
}
func TestClosingStreamsFromProxy(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
for clientConfigName, clientConfig := range map[string]client.RawConfig{"basic": basicTCPConfig, "singleplex": singleplexTCPConfig} {
clientConfig := clientConfig
clientConfigName := clientConfigName
t.Run(clientConfigName, func(t *testing.T) {
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("closing from server", func(t *testing.T) {
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Close()
assert.Eventually(t, func() bool {
_, err := clientConn.Read(make([]byte, 16))
return err != nil
}, time.Second, 10*time.Millisecond, "closing stream on server side is not reflected to the client")
})
t.Run("closing from client", func(t *testing.T) {
// closing stream on client side
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
clientConn.Close()
assert.Eventually(t, func() bool {
_, err := serverConn.Read(make([]byte, 16))
return err != nil
}, time.Second, 10*time.Millisecond, "closing stream on client side is not reflected to the server")
})
t.Run("send then close", func(t *testing.T) {
testData := make([]byte, 24*1024)
rand.Read(testData)
clientConn, _ := proxyToCkClientD.Dial("", "")
go func() {
clientConn.Write(testData)
// it takes time for this written data to be copied asynchronously
// into ck-server's domain. If the pipe is closed before that, read
// by ck-client in RouteTCP will fail as we have closed it.
time.Sleep(700 * time.Millisecond)
clientConn.Close()
}()
readBuf := make([]byte, len(testData))
serverConn, err := proxyFromCkServerL.Accept()
if err != nil {
t.Errorf("failed to accept a connection delievering data sent before closing: %v", err)
}
_, err = io.ReadFull(serverConn, readBuf)
if err != nil {
t.Errorf("failed to read data sent before closing: %v", err)
}
})
})
}
}
func BenchmarkIntegration(b *testing.B) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState)
const bufSize = 16 * 1024
encryptionMethods := map[string]byte{
"plain": mux.EncryptionMethodPlain,
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
"aes-256-gcm": mux.EncryptionMethodAES256GCM,
"aes-128-gcm": mux.EncryptionMethodAES128GCM,
}
for name, method := range encryptionMethods {
b.Run(name, func(b *testing.B) {
ai.EncryptionMethod = method
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
b.Fatal(err)
}
b.Run("single stream bandwidth", func(b *testing.B) {
more := make(chan int, 10)
go func() {
// sender
writeBuf := make([]byte, bufSize+100)
serverConn, _ := proxyFromCkServerL.Accept()
for {
serverConn.Write(writeBuf)
<-more
}
}()
// receiver
clientConn, _ := proxyToCkClientD.Dial("", "")
readBuf := make([]byte, bufSize)
clientConn.Write([]byte{1}) // to make server accept
b.SetBytes(bufSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
io.ReadFull(clientConn, readBuf)
// ask for more
more <- 0
}
})
b.Run("single stream latency", func(b *testing.B) {
clientConn, _ := proxyToCkClientD.Dial("", "")
buf := []byte{1}
clientConn.Write(buf)
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
clientConn.Write(buf)
serverConn.Read(buf)
}
})
})
}
}

View File

@ -1 +0,0 @@
package test

116
internal/util/util.go Normal file
View File

@ -0,0 +1,116 @@
package util
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
"net"
"strconv"
"time"
)
func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aesgcm.Seal(nil, nonce, plaintext, nil), nil
}
func AESGCMDecrypt(nonce []byte, key []byte, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plain, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plain, nil
}
// ReadTLS reads TLS data according to its record layer
func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
// TCP is a stream. Multiple TLS messages can arrive at the same time,
// a single message can also be segmented due to MTU of the IP layer.
// This function guareentees a single TLS message to be read and everything
// else is left in the buffer.
i, err := io.ReadFull(conn, buffer[:5])
if err != nil {
return
}
dataLength := int(binary.BigEndian.Uint16(buffer[3:5]))
if dataLength > len(buffer) {
err = errors.New("Reading TLS message: message size greater than buffer. message size: " + strconv.Itoa(dataLength))
return
}
left := dataLength
readPtr := 5
for left != 0 {
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
// if left = buffer size, the entire buffer is all there left to read
// if left < buffer size (i.e. multiple messages came together),
// only the message we want is read
i, err = conn.Read(buffer[readPtr : readPtr+left])
if err != nil {
return
}
left -= i
readPtr += i
}
n = 5 + dataLength
return
}
// AddRecordLayer adds record layer to data
func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) {
// The maximum size of TLS message will be 16380+14+16. 14 because of the stream header and 16
// because of the salt/mac
// 16408 is the max TLS message size on Firefox
buf := make([]byte, 16378)
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
}
for {
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
}
i, err := io.ReadAtLeast(src, buf, 1)
if err != nil {
dst.Close()
src.Close()
return
}
i, err = dst.Write(buf[:i])
if err != nil {
dst.Close()
src.Close()
return
}
}
}

View File

@ -0,0 +1,26 @@
package util
import (
"io"
"io/ioutil"
"math/rand"
"testing"
)
func BenchmarkPipe(b *testing.B) {
reader := rand.New(rand.NewSource(42))
buf := make([]byte, 16380)
for i := 0; i < b.N; i++ {
n, err := io.ReadAtLeast(reader, buf, 1)
if err != nil {
b.Error(err)
return
}
n, err = ioutil.Discard.Write(buf[:n])
if err != nil {
b.Error(err)
return
}
b.SetBytes(int64(n))
}
}

View File

@ -1,16 +1,14 @@
package common package util
import ( import (
"errors" "errors"
"github.com/gorilla/websocket"
"io" "io"
"net"
"sync" "sync"
"time" "time"
"github.com/gorilla/websocket"
) )
// WebSocketConn implements io.ReadWriteCloser
// it makes websocket.Conn binary-oriented
type WebSocketConn struct { type WebSocketConn struct {
*websocket.Conn *websocket.Conn
writeM sync.Mutex writeM sync.Mutex
@ -75,3 +73,8 @@ func (ws *WebSocketConn) SetDeadline(t time.Time) error {
} }
return nil return nil
} }
// ws unit reader
func ReadWebSocket(conn net.Conn, buffer []byte) (n int, err error) {
return conn.Read(buffer)
}

View File

@ -1,12 +1,15 @@
#!/usr/bin/env bash go get github.com/mitchellh/gox
set -eu
go install github.com/mitchellh/gox@latest
mkdir -p release mkdir -p release
rm -f ./release/* read -p "Cleaning $PWD/release directory. Proceed? [y/n]" res
if [ ! "$res" == "y" ]; then
echo "Abort"
exit 1
fi
rm -rf ./release/*
if [ -z "$v" ]; then if [ -z "$v" ]; then
echo "Version number cannot be null. Run with v=[version] release.sh" echo "Version number cannot be null. Run with v=[version] release.sh"
@ -14,24 +17,20 @@ if [ -z "$v" ]; then
fi fi
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v" output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
osarch="!darwin/arm !darwin/386" osarch="!darwin/arm !darwin/arm64"
echo "Compiling:" echo "Compiling:"
os="windows linux darwin" os="windows linux darwin"
arch="amd64 386 arm arm64 mips mips64 mipsle mips64le" arch="amd64 386 arm arm64 mips mips64 mipsle mips64le"
pushd cmd/ck-client pushd cmd/ck-client
CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output" gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output"
CGO_ENABLED=0 GOOS="linux" GOARCH="mips" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mips_softfloat-"${v}" GOOS="linux" GOARCH="mips" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mips_softfloat-${v}
CGO_ENABLED=0 GOOS="linux" GOARCH="mipsle" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mipsle_softfloat-"${v}" GOOS="linux" GOARCH="mipsle" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mipsle_softfloat-${v}
mv ck-client-* ../../release mv ck-client-* ../../release
popd
os="linux" os="linux"
arch="amd64 386 arm arm64" arch="amd64 386 arm arm64"
pushd cmd/ck-server pushd ../ck-server
CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output" gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output"
mv ck-server-* ../../release mv ck-server-* ../../release
popd
sha256sum release/*

View File

@ -1,13 +0,0 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
],
"packageRules": [
{
"packagePatterns": ["*"],
"excludePackagePatterns": ["utls"],
"enabled": false
}
]
}